فهرست منبع

modify engine

yjh0410 2 سال پیش
والد
کامیت
60aedf5a8f
1فایلهای تغییر یافته به همراه7 افزوده شده و 7 حذف شده
  1. 7 7
      models/yolov1/yolov1.py

+ 7 - 7
models/yolov1/yolov1.py

@@ -89,15 +89,15 @@ class YOLOv1(nn.Module):
         grid_cell = self.create_grid(fmp_size)
 
         # 计算预测边界框的中心点坐标和宽高
-        pred[..., :2] = (torch.sigmoid(pred[..., :2]) + grid_cell) * self.stride
-        pred[..., 2:] = torch.exp(pred[..., 2:])
+        pred_ctr = (torch.sigmoid(pred[..., :2]) + grid_cell) * self.stride
+        pred_wh = torch.exp(pred[..., 2:])
 
         # 将所有bbox的中心带你坐标和宽高换算成x1y1x2y2形式
-        output = torch.zeros_like(pred)
-        output[..., :2] = pred[..., :2] - pred[..., 2:] * 0.5
-        output[..., 2:] = pred[..., :2] + pred[..., 2:] * 0.5
-        
-        return output
+        pred_x1y1 = pred_ctr - pred_wh * 0.5
+        pred_x2y2 = pred_ctr + pred_wh * 0.5
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
 
 
     def postprocess(self, bboxes, scores):