yjh0410 2 gadi atpakaļ
vecāks
revīzija
ce7e4a4bfc
2 mainītis faili ar 14 papildinājumiem un 23 dzēšanām
  1. 6 7
      models/detectors/yolovx/loss.py
  2. 8 16
      models/detectors/yolovx/matcher.py

+ 6 - 7
models/detectors/yolovx/loss.py

@@ -89,11 +89,10 @@ class Criterion(object):
                 fg_mask = obj_preds.new_zeros(num_anchors).bool()
             else:
                 (
-                    gt_matched_classes,
                     fg_mask,
-                    pred_ious_this_matching,
-                    matched_gt_inds,
-                    num_fg_img,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
                 ) = self.matcher(
                     fpn_strides = fpn_strides,
                     anchors = anchors,
@@ -105,9 +104,9 @@ class Criterion(object):
                     )
 
                 obj_target = fg_mask.unsqueeze(-1)
-                cls_target = F.one_hot(gt_matched_classes.long(), self.num_classes)
-                cls_target = cls_target * pred_ious_this_matching.unsqueeze(-1)
-                box_target = tgt_bboxes[matched_gt_inds]
+                cls_target = F.one_hot(assigned_labels.long(), self.num_classes)
+                cls_target = cls_target * assigned_ious.unsqueeze(-1)
+                box_target = tgt_bboxes[assigned_indexs]
 
             cls_targets.append(cls_target)
             box_targets.append(box_target)

+ 8 - 16
models/detectors/yolovx/matcher.py

@@ -68,10 +68,9 @@ class AlignedSimOTA(object):
         ) # [N, Mp]
 
         (
-            num_fg,
-            gt_matched_classes,         # [num_fg,]
-            pred_ious_this_matching,    # [num_fg,]
-            matched_gt_inds,            # [num_fg,]
+            assigned_labels,         # [num_fg,]
+            assigned_ious,           # [num_fg,]
+            assigned_indexs,         # [num_fg,]
         ) = self.dynamic_k_matching(
             cost_matrix,
             pair_wise_ious,
@@ -81,13 +80,7 @@ class AlignedSimOTA(object):
             )
         del cls_cost, cost_matrix, pair_wise_ious, reg_cost
 
-        return (
-                gt_matched_classes,
-                fg_mask,
-                pred_ious_this_matching,
-                matched_gt_inds,
-                num_fg,
-        )
+        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
 
 
     def get_in_boxes_info(
@@ -182,15 +175,14 @@ class AlignedSimOTA(object):
             matching_matrix[:, anchor_matching_gt > 1] *= 0
             matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
         fg_mask_inboxes = matching_matrix.sum(0) > 0
-        num_fg = fg_mask_inboxes.sum().item()
 
         fg_mask[fg_mask.clone()] = fg_mask_inboxes
 
-        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
-        gt_matched_classes = gt_classes[matched_gt_inds]
+        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        assigned_labels = gt_classes[assigned_indexs]
 
-        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
+        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
             fg_mask_inboxes
         ]
-        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
+        return assigned_labels, assigned_ious, assigned_indexs