matcher.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # ---------------------------------------------------------------------
  2. # Copyright (c) Megvii Inc. All rights reserved.
  3. # ---------------------------------------------------------------------
  4. import torch
  5. import torch.nn.functional as F
  6. from utils.box_ops import *
  7. # YOLOX SimOTA
  8. class SimOTA(object):
  9. """
  10. This code referenced to https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
  11. """
  12. def __init__(self, num_classes, center_sampling_radius, topk_candidate ):
  13. self.num_classes = num_classes
  14. self.center_sampling_radius = center_sampling_radius
  15. self.topk_candidate = topk_candidate
  16. @torch.no_grad()
  17. def __call__(self,
  18. fpn_strides,
  19. anchors,
  20. pred_obj,
  21. pred_cls,
  22. pred_box,
  23. tgt_labels,
  24. tgt_bboxes):
  25. # [M,]
  26. strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
  27. for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
  28. # List[F, M, 2] -> [M, 2]
  29. anchors = torch.cat(anchors, dim=0)
  30. num_anchor = anchors.shape[0]
  31. num_gt = len(tgt_labels)
  32. fg_mask, is_in_boxes_and_center = \
  33. self.get_in_boxes_info(
  34. tgt_bboxes,
  35. anchors,
  36. strides,
  37. num_anchor,
  38. num_gt
  39. )
  40. obj_preds_ = pred_obj[fg_mask] # [Mp, 1]
  41. cls_preds_ = pred_cls[fg_mask] # [Mp, C]
  42. box_preds_ = pred_box[fg_mask] # [Mp, 4]
  43. num_in_boxes_anchor = box_preds_.shape[0]
  44. # [N, Mp]
  45. pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds_)
  46. pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
  47. # [N, C] -> [N, Mp, C]
  48. gt_cls = (
  49. F.one_hot(tgt_labels.long(), self.num_classes)
  50. .float()
  51. .unsqueeze(1)
  52. .repeat(1, num_in_boxes_anchor, 1)
  53. )
  54. with torch.cuda.amp.autocast(enabled=False):
  55. score_preds_ = torch.sqrt(
  56. cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
  57. * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
  58. ) # [N, Mp, C]
  59. pair_wise_cls_loss = F.binary_cross_entropy(
  60. score_preds_, gt_cls, reduction="none"
  61. ).sum(-1) # [N, Mp]
  62. del score_preds_
  63. cost = (
  64. pair_wise_cls_loss
  65. + 3.0 * pair_wise_ious_loss
  66. + 100000.0 * (~is_in_boxes_and_center)
  67. ) # [N, Mp]
  68. (
  69. num_fg,
  70. gt_matched_classes, # [num_fg,]
  71. pred_ious_this_matching, # [num_fg,]
  72. matched_gt_inds, # [num_fg,]
  73. ) = self.dynamic_k_matching(
  74. cost,
  75. pair_wise_ious,
  76. tgt_labels,
  77. num_gt,
  78. fg_mask
  79. )
  80. del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
  81. return (
  82. gt_matched_classes,
  83. fg_mask,
  84. pred_ious_this_matching,
  85. matched_gt_inds,
  86. num_fg,
  87. )
  88. def get_in_boxes_info(
  89. self,
  90. gt_bboxes, # [N, 4]
  91. anchors, # [M, 2]
  92. strides, # [M,]
  93. num_anchors, # M
  94. num_gt, # N
  95. ):
  96. # anchor center
  97. x_centers = anchors[:, 0]
  98. y_centers = anchors[:, 1]
  99. # [M,] -> [1, M] -> [N, M]
  100. x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
  101. y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
  102. # [N,] -> [N, 1] -> [N, M]
  103. gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
  104. gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
  105. gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
  106. gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
  107. b_l = x_centers - gt_bboxes_l
  108. b_r = gt_bboxes_r - x_centers
  109. b_t = y_centers - gt_bboxes_t
  110. b_b = gt_bboxes_b - y_centers
  111. bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
  112. is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
  113. is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
  114. # in fixed center
  115. center_radius = self.center_sampling_radius
  116. # [N, 2]
  117. gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
  118. # [1, M]
  119. center_radius_ = center_radius * strides.unsqueeze(0)
  120. gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
  121. gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
  122. gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
  123. gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
  124. c_l = x_centers - gt_bboxes_l
  125. c_r = gt_bboxes_r - x_centers
  126. c_t = y_centers - gt_bboxes_t
  127. c_b = gt_bboxes_b - y_centers
  128. center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
  129. is_in_centers = center_deltas.min(dim=-1).values > 0.0
  130. is_in_centers_all = is_in_centers.sum(dim=0) > 0
  131. # in boxes and in centers
  132. is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
  133. is_in_boxes_and_center = (
  134. is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
  135. )
  136. return is_in_boxes_anchor, is_in_boxes_and_center
  137. def dynamic_k_matching(
  138. self,
  139. cost,
  140. pair_wise_ious,
  141. gt_classes,
  142. num_gt,
  143. fg_mask
  144. ):
  145. # Dynamic K
  146. # ---------------------------------------------------------------
  147. matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
  148. ious_in_boxes_matrix = pair_wise_ious
  149. n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
  150. topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
  151. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
  152. dynamic_ks = dynamic_ks.tolist()
  153. for gt_idx in range(num_gt):
  154. _, pos_idx = torch.topk(
  155. cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
  156. )
  157. matching_matrix[gt_idx][pos_idx] = 1
  158. del topk_ious, dynamic_ks, pos_idx
  159. anchor_matching_gt = matching_matrix.sum(0)
  160. if (anchor_matching_gt > 1).sum() > 0:
  161. _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
  162. matching_matrix[:, anchor_matching_gt > 1] *= 0
  163. matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
  164. fg_mask_inboxes = matching_matrix.sum(0) > 0
  165. num_fg = fg_mask_inboxes.sum().item()
  166. fg_mask[fg_mask.clone()] = fg_mask_inboxes
  167. matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
  168. gt_matched_classes = gt_classes[matched_gt_inds]
  169. pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
  170. fg_mask_inboxes
  171. ]
  172. return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds