yjh0410 пре 1 година
родитељ
комит
9c22313c65
2 измењених фајлова са 3 додато и 14 уклоњено
  1. 1 12
      models/detectors/rtdetr/rtdetr.py
  2. 2 2
      models/detectors/rtdetr/rtdetr_decoder.py

+ 1 - 12
models/detectors/rtdetr/rtdetr.py

@@ -98,22 +98,11 @@ class RT_DETR(nn.Module):
             box_preds = pred_boxes[-1]
             cls_preds = pred_logits[-1]
             
-            # TODO: post-process
+            # post-process
             bboxes, scores, labels = self.post_process(box_preds, cls_preds)
 
             return bboxes, scores, labels
         
-        # ----------- Head -----------
-        outputs = self.detect_head(pred_boxes, pred_logits, enc_topk_bboxes, enc_topk_logits, dn_meta, targets)
-
-        if self.training:
-            outputs_dict = outputs
-        else:
-            pred_boxes, pred_logits = outputs[0], outputs[1]
-            return pred_boxes, pred_logits
-            
-        return outputs_dict
-
 
 if __name__ == '__main__':
     import time

+ 2 - 2
models/detectors/rtdetr/rtdetr_decoder.py

@@ -245,14 +245,14 @@ class RTDETRTransformer(nn.Module):
 
         return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
     
-    def forward(self, feats, gt_meta=None):
+    def forward(self, feats, targets=None):
         # input projection and embedding
         memory, spatial_shapes, _ = self.get_encoder_input(feats)
 
         # prepare denoising training
         if self.training:
             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
-                get_contrastive_denoising_training_group(gt_meta,
+                get_contrastive_denoising_training_group(targets,
                                                          self.num_classes,
                                                          self.num_queries,
                                                          self.denoising_class_embed.weight,