|
|
@@ -89,15 +89,15 @@ class YOLOv1(nn.Module):
|
|
|
grid_cell = self.create_grid(fmp_size)
|
|
|
|
|
|
# 计算预测边界框的中心点坐标和宽高
|
|
|
- pred[..., :2] = (torch.sigmoid(pred[..., :2]) + grid_cell) * self.stride
|
|
|
- pred[..., 2:] = torch.exp(pred[..., 2:])
|
|
|
+ pred_ctr = (torch.sigmoid(pred[..., :2]) + grid_cell) * self.stride
|
|
|
+ pred_wh = torch.exp(pred[..., 2:])
|
|
|
|
|
|
# 将所有bbox的中心带你坐标和宽高换算成x1y1x2y2形式
|
|
|
- output = torch.zeros_like(pred)
|
|
|
- output[..., :2] = pred[..., :2] - pred[..., 2:] * 0.5
|
|
|
- output[..., 2:] = pred[..., :2] + pred[..., 2:] * 0.5
|
|
|
-
|
|
|
- return output
|
|
|
+ pred_x1y1 = pred_ctr - pred_wh * 0.5
|
|
|
+ pred_x2y2 = pred_ctr + pred_wh * 0.5
|
|
|
+ pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
|
|
|
+
|
|
|
+ return pred_box
|
|
|
|
|
|
|
|
|
def postprocess(self, bboxes, scores):
|