|
|
@@ -68,10 +68,9 @@ class AlignedSimOTA(object):
|
|
|
) # [N, Mp]
|
|
|
|
|
|
(
|
|
|
- num_fg,
|
|
|
- gt_matched_classes, # [num_fg,]
|
|
|
- pred_ious_this_matching, # [num_fg,]
|
|
|
- matched_gt_inds, # [num_fg,]
|
|
|
+ assigned_labels, # [num_fg,]
|
|
|
+ assigned_ious, # [num_fg,]
|
|
|
+ assigned_indexs, # [num_fg,]
|
|
|
) = self.dynamic_k_matching(
|
|
|
cost_matrix,
|
|
|
pair_wise_ious,
|
|
|
@@ -81,13 +80,7 @@ class AlignedSimOTA(object):
|
|
|
)
|
|
|
del cls_cost, cost_matrix, pair_wise_ious, reg_cost
|
|
|
|
|
|
- return (
|
|
|
- gt_matched_classes,
|
|
|
- fg_mask,
|
|
|
- pred_ious_this_matching,
|
|
|
- matched_gt_inds,
|
|
|
- num_fg,
|
|
|
- )
|
|
|
+ return fg_mask, assigned_labels, assigned_ious, assigned_indexs
|
|
|
|
|
|
|
|
|
def get_in_boxes_info(
|
|
|
@@ -182,15 +175,14 @@ class AlignedSimOTA(object):
|
|
|
matching_matrix[:, anchor_matching_gt > 1] *= 0
|
|
|
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
|
|
|
fg_mask_inboxes = matching_matrix.sum(0) > 0
|
|
|
- num_fg = fg_mask_inboxes.sum().item()
|
|
|
|
|
|
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
|
|
|
|
|
- matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
|
|
- gt_matched_classes = gt_classes[matched_gt_inds]
|
|
|
+ assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
|
|
+ assigned_labels = gt_classes[assigned_indexs]
|
|
|
|
|
|
- pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
|
|
|
+ assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
|
|
|
fg_mask_inboxes
|
|
|
]
|
|
|
- return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
|
|
+ return assigned_labels, assigned_ious, assigned_indexs
|
|
|
|