|
|
@@ -109,7 +109,7 @@ class YOLOv4(nn.Module):
|
|
|
anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
|
|
|
# [HW, 2] -> [HW, KA, 2] -> [M, 2]
|
|
|
anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
|
|
|
- anchor_xy = anchor_xy.view(-1, 2).to(self.device)
|
|
|
+ anchor_xy = anchor_xy.view(-1, 2).to(self.device) + 0.5
|
|
|
|
|
|
# [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
|
|
|
anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
|
|
|
@@ -126,7 +126,7 @@ class YOLOv4(nn.Module):
|
|
|
"""
|
|
|
|
|
|
# 计算预测边界框的中心点坐标和宽高
|
|
|
- pred_ctr = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride[level]
|
|
|
+ pred_ctr = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
|
|
|
pred_wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
|
|
|
|
|
|
# 将所有bbox的中心带你坐标和宽高换算成x1y1x2y2形式
|