浏览代码

fix a bug in num_gts

yjh0410 1 年之前
父节点
当前提交
0040d028b4
共有 2 个文件被更改,包括 32 次插入2 次删除
  1. 30 0
      engine.py
  2. 2 2
      models/detectors/rtdetr/matcher.py

+ 30 - 0
engine.py

@@ -1303,6 +1303,10 @@ class RTRTrainer(object):
                     images, targets, self.model_cfg['out_stride'][-1], self.args.min_box_size, self.model_cfg['multi_scale'])
             else:
                 targets = self.refine_targets(img_size, targets, self.args.min_box_size)
+
+            # xyxy -> cxcybwbh
+            targets = self.box_xyxy_to_cxcywh(targets)
+            print(targets)
                 
             # Visualize train targets
             if self.args.vis_tgt:
@@ -1413,6 +1417,32 @@ class RTRTrainer(object):
 
         return images, targets, new_img_size
 
+    def box_xyxy_to_cxcywh(self, targets):
+        """Convert each target's boxes from corner form to center/size form.
+
+        Every element of ``targets`` is indexed with the key ``"boxes"``; that
+        entry is replaced by a new tensor, so the passed-in containers are
+        updated in place and the same ``targets`` object is returned.
+
+        NOTE(review): assumes the last dimension of ``"boxes"`` is laid out as
+        (x1, y1, x2, y2) -- confirm against the dataset/transform pipeline.
+        """
+        for target in targets:
+            corners = target["boxes"].clone()
+            top_left = corners[..., :2]
+            bottom_right = corners[..., 2:]
+            # center = midpoint of the two corners; size = corner difference
+            centers = (top_left + bottom_right) * 0.5
+            sizes = bottom_right - top_left
+            target["boxes"] = torch.cat([centers, sizes], dim=-1)
+
+        return targets
+
+    def box_cxcywh_to_xyxy(self, targets):
+        """Convert each target's boxes from center/size form to corner form.
+
+        Inverse of ``box_xyxy_to_cxcywh``: every element of ``targets`` is
+        indexed with the key ``"boxes"`` and that entry is replaced by a new
+        tensor whose last dimension is (x1, y1, x2, y2). The containers are
+        updated in place and the same ``targets`` object is returned.
+
+        NOTE(review): assumes the last dimension of ``"boxes"`` is laid out as
+        (cx, cy, w, h) -- confirm against the caller.
+        """
+        for tgt in targets:
+            boxes_cxcywh = tgt["boxes"]
+            cxcy = boxes_cxcywh[..., :2]
+            half_wh = boxes_cxcywh[..., 2:] * 0.5
+            # Corners lie half a width/height on either side of the center.
+            # (The previous body reused the xyxy->cxcywh formula and
+            # concatenated a 4-column and a 2-column tensor.)
+            x1y1 = cxcy - half_wh
+            x2y2 = cxcy + half_wh
+            tgt["boxes"] = torch.cat([x1y1, x2y2], dim=-1)
+
+        return targets
+
     def check_second_stage(self):
         # set second stage
         print('============== Second stage of Training ==============')

+ 2 - 2
models/detectors/rtdetr/matcher.py

@@ -36,9 +36,9 @@ class HungarianMatcher(nn.Module):
 
         # -------------------- Regression cost --------------------
         ## L1 cost: [Nq, M]
-        cost_bbox = torch.cdist(out_bbox, box_xyxy_to_cxcywh(tgt_bbox).to(out_bbox.device), p=1)
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox.to(out_bbox.device), p=1)
         ## GIoU cost: [Nq, M]
-        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), tgt_bbox.to(out_bbox.device))
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox).to(out_bbox.device))
 
         # Final cost: [B, Nq, M]
         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou