yjh0410 2 年之前
父節點
當前提交
3f55ad49cf
共有 2 個文件被更改,包括 31 次插入和 30 次删除
  1. 6 12
      models/detectors/rtdetr/rtdetr.py
  2. 25 18
      models/detectors/rtdetr/rtdetr_dethead.py

+ 6 - 12
models/detectors/rtdetr/rtdetr.py

@@ -64,18 +64,7 @@ class RTDETR(nn.Module):
 
         # -------------------- DetHead --------------------
         out_logits, out_bbox = self.dethead(hs, reference, False)
-
-        # -------------------- Decode bbox --------------------
-        cls_pred = out_logits[0]
-        box_pred = out_bbox[0]
-        ## cxcywh -> xyxy
-        x1y1_pred = box_pred[..., :2] - box_pred[..., 2:] * 0.5
-        x2y2_pred = box_pred[..., :2] + box_pred[..., 2:] * 0.5
-        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
-        ## denormalize bbox
-        img_h, img_w = x.shape[-2:]
-        box_pred[..., 0::2] *= img_w
-        box_pred[..., 1::2] *= img_h
+        cls_pred, box_pred = out_logits[0], out_bbox[0]
 
         # -------------------- Top-k --------------------
         cls_pred = cls_pred.flatten().sigmoid_()
@@ -87,6 +76,11 @@ class RTDETR(nn.Module):
         topk_labels = topk_idxs % self.num_classes
         topk_bboxes = box_pred[topk_box_idxs]
 
+        # denormalize the selected boxes: topk_bboxes was gathered from box_pred by index above (a new tensor), so scaling box_pred here would never reach the returned boxes
+        img_h, img_w = x.shape[-2:]
+        topk_bboxes[..., 0::2] *= img_w
+        topk_bboxes[..., 1::2] *= img_h
+
         if self.deploy:
             return topk_bboxes, topk_scores, topk_labels
         else:

+ 25 - 18
models/detectors/rtdetr/rtdetr_dethead.py

@@ -37,38 +37,45 @@ class DetectHead(nn.Module):
             nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
         
 
+    def inverse_sigmoid(self, x):  # logit log(x / (1 - x)) with numerical guards
+        x = x.clamp(min=0, max=1)  # force the input into [0, 1] first
+        return torch.log(x.clamp(min=1e-5)/(1 - x).clamp(min=1e-5))  # floor numerator and denominator at 1e-5 to avoid log(0) and division by zero
+
+
+    def decode_bbox(self, outputs_coords):  # builds and returns a new tensor via torch.cat
+        ## cxcywh -> xyxy: first two channels are the center, the rest the size (assumes a last dim of 4); corners are center -/+ half the size
+        x1y1_pred = outputs_coords[..., :2] - outputs_coords[..., 2:] * 0.5
+        x2y2_pred = outputs_coords[..., :2] + outputs_coords[..., 2:] * 0.5
+        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)  # last dim becomes (x1, y1, x2, y2)
+        
+        return box_pred
+
+
     def forward(self, hs, reference, multi_layer=False):
         if multi_layer:
-            ## class embed
-            outputs_class = torch.stack([layer_cls_embed(layer_hs) for
-                                        layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
-            ## Bbox embed
+            # class embed: one classifier per entry of hs, logits stacked along a new leading dim
+            outputs_class = torch.stack([
+                layer_cls_embed(layer_hs) for layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
+            # bbox embed: per layer, add the predicted delta to the logit of that layer's reference, then sigmoid
             outputs_coords = []
             for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
                 layer_delta_unsig = layer_bbox_embed(layer_hs)
-                # ---------- start <inverse sigmoid> ----------
-                layer_ref_sig = layer_ref_sig.clamp(min=0, max=1)
-                layer_ref_sig_1 = layer_ref_sig.clamp(min=1e-5)
-                layer_ref_sig_2 = (1 - layer_ref_sig).clamp(min=1e-5)
-                layer_ref_sig = torch.log(layer_ref_sig_1/layer_ref_sig_2)
-                # ---------- end <inverse sigmoid> ----------
+                layer_ref_sig = self.inverse_sigmoid(layer_ref_sig)
                 layer_outputs_unsig = layer_delta_unsig + layer_ref_sig
                 layer_outputs_unsig = layer_outputs_unsig.sigmoid()
                 outputs_coords.append(layer_outputs_unsig)
         else:
-            ## class embed
+            # class embed: last classifier applied to the last entry of hs only
             outputs_class = self.class_embed[-1](hs[-1]) 
-            ## bbox embed
+            # bbox embed: last box head's delta added to the logit of reference[-2], then sigmoid
             delta_unsig = self.bbox_embed[-1](hs[-1])
             ref_sig = reference[-2]
-            ## ---------- start <inverse sigmoid> ----------
-            ref_sig = ref_sig.clamp(min=0, max=1)
-            ref_sig_1 = ref_sig.clamp(min=1e-5)
-            ref_sig_2 = (1 - ref_sig).clamp(min=1e-5)
-            ref_sig = torch.log(ref_sig_1/ref_sig_2)
-            ## ---------- end <inverse sigmoid> ----------
+            ref_sig = self.inverse_sigmoid(ref_sig)
             outputs_unsig = delta_unsig + ref_sig
             outputs_coords = outputs_unsig.sigmoid()
+            # decode bbox: cxcywh -> xyxy. NOTE(review): only this branch decodes; the multi_layer branch returns a plain list of undecoded per-layer tensors - confirm callers expect that
+            outputs_coords = self.decode_bbox(outputs_coords)
+
 
         return outputs_class, outputs_coords