فهرست منبع

modify matcher of YOLOv4

yjh0410 2 سال پیش
والد
کامیت
ae168b8e6f
1فایلهای تغییر یافته به همراه2 افزوده شده و 2 حذف شده
  1. 2 2
      models/yolov4/yolov4.py

+ 2 - 2
models/yolov4/yolov4.py

@@ -109,7 +109,7 @@ class YOLOv4(nn.Module):
         anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
         # [HW, 2] -> [HW, KA, 2] -> [M, 2]
         anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
-        anchor_xy = anchor_xy.view(-1, 2).to(self.device)
+        anchor_xy = anchor_xy.view(-1, 2).to(self.device) + 0.5
 
         # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
         anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
@@ -126,7 +126,7 @@ class YOLOv4(nn.Module):
         """
 
         # 计算预测边界框的中心点坐标和宽高
-        pred_ctr = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride[level]
+        pred_ctr = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
         pred_wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
 
         # 将所有bbox的中心带你坐标和宽高换算成x1y1x2y2形式