matcher.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from utils.box_ops import box_iou, bbox_iou
  5. # -------------------------- Basic Functions --------------------------
  def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
      """Select the anchors whose centers lie strictly inside each gt box.

      Args:
          xy_centers (Tensor): shape(num_total_anchors, 2), anchor centers (x, y).
          gt_bboxes (Tensor): shape(bs, n_max_boxes, 4), boxes as (x1, y1, x2, y2).
          eps (float): a center must be farther than ``eps`` from every edge.
      Return:
          (Tensor): shape(bs, n_max_boxes, num_total_anchors), 1 where the anchor
              center is inside the gt box and 0 elsewhere, in ``gt_bboxes.dtype``.
      """
      # Broadcast instead of materialising (bs * n_max_boxes, num_anchors, 2)
      # copies; this also keeps every shape well defined when n_max_boxes == 0
      # (a reshape with -1 on an empty tensor is ambiguous and raises).
      centers = xy_centers[None, None]   # (1, 1, num_anchors, 2)
      gt_lt = gt_bboxes[..., None, 0:2]  # (bs, n_max_boxes, 1, 2)
      gt_rb = gt_bboxes[..., None, 2:4]  # (bs, n_max_boxes, 1, 2)
      # distances from the center to the left/top and right/bottom edges
      bbox_deltas = torch.cat([centers - gt_lt, gt_rb - centers], dim=-1)  # (bs, n, M, 4)
      return (bbox_deltas.amin(dim=-1) > eps).to(gt_bboxes.dtype)
  def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
      """Keep a single gt per anchor, preferring the one with the highest overlap.

      An anchor flagged positive for several gts is re-assigned to the gt slot
      with the largest entry in ``overlaps`` (taken over all gt slots, not only
      the positive ones). Anchors with at most one positive gt are unchanged.

      Args:
          mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
          overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
          n_max_boxes (int): size of the gt dimension of the two tensors above.
      Return:
          target_gt_idx (Tensor): shape(bs, num_total_anchors); 0 for anchors
              without any positive gt.
          fg_mask (Tensor): shape(bs, num_total_anchors), number of positive gts
              per anchor after resolution.
          mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
      """
      gts_per_anchor = mask_pos.sum(dim=-2)
      if not gts_per_anchor.max() > 1:
          # no anchor is claimed twice: nothing to resolve
          return mask_pos.argmax(dim=-2), gts_per_anchor, mask_pos

      # (bs, num_anchors) -> (bs, n_max_boxes, num_anchors)
      conflicted = (gts_per_anchor.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
      # one-hot (over the gt axis) of the best-overlap gt for every anchor
      best_gt = F.one_hot(overlaps.argmax(dim=1), n_max_boxes)
      best_gt = best_gt.permute(0, 2, 1).to(overlaps.dtype)
      resolved = torch.where(conflicted, best_gt, mask_pos)
      return resolved.argmax(dim=-2), resolved.sum(dim=-2), resolved
  def iou_calculator(box1, box2, eps=1e-9):
      """Compute the pairwise IoU between two (batched) sets of boxes.

      Args:
          box1 (Tensor): shape(..., M1, 4), boxes as (x1, y1, x2, y2).
          box2 (Tensor): shape(..., M2, 4), boxes as (x1, y1, x2, y2).
              The leading dimensions of box1 and box2 must broadcast; plain
              2-D inputs (M1, 4) / (M2, 4) are accepted as well.
          eps (float): added to the union to avoid division by zero.
      Return:
          (Tensor): shape(..., M1, M2); e.g. (bs, n_max_boxes, num_total_anchors)
              for box1 (bs, n_max_boxes, 4) and box2 (bs, num_total_anchors, 4).
      """
      box1 = box1.unsqueeze(-2)  # [..., M1, 4] -> [..., M1, 1, 4]
      box2 = box2.unsqueeze(-3)  # [..., M2, 4] -> [..., 1, M2, 4]
      px1y1, px2y2 = box1[..., 0:2], box1[..., 2:4]
      gx1y1, gx2y2 = box2[..., 0:2], box2[..., 2:4]
      x1y1 = torch.maximum(px1y1, gx1y1)
      x2y2 = torch.minimum(px2y2, gx2y2)
      overlap = (x2y2 - x1y1).clip(0).prod(-1)
      # inverted (degenerate) boxes contribute zero area
      area1 = (px2y2 - px1y1).clip(0).prod(-1)
      area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
      union = area1 + area2 - overlap + eps
      return overlap / union
  65. # -------------------------- Task Aligned Assigner --------------------------
  class TaskAlignedAssigner(nn.Module):
      """Task-aligned label assigner.

      For every gt box, anchors are ranked by the alignment metric
      ``score ** alpha * iou ** beta``; the ``topk`` best anchors whose centers
      lie inside the box become positives, and an anchor claimed by several
      gts is given to the gt with the highest overlap.

      Args:
          topk (int): number of candidate anchors kept per gt box.
          alpha (float): exponent applied to the predicted class score.
          beta (float): exponent applied to the IoU.
          eps (float): small constant guarding comparisons and divisions.
          num_classes (int): number of classes.
      """
      def __init__(self, topk=10, alpha=0.5, beta=6.0, eps=1e-9, num_classes=80):
          super(TaskAlignedAssigner, self).__init__()
          self.topk = topk
          self.num_classes = num_classes
          # background label; returned for every anchor when a batch has no gts
          self.bg_idx = num_classes
          self.alpha = alpha
          self.beta = beta
          self.eps = eps

      @torch.no_grad()
      def forward(self,
                  pd_scores,
                  pd_bboxes,
                  anc_points,
                  gt_labels,
                  gt_bboxes):
          """Assign gt targets to anchors.

          This code referenced to
          https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py

          Args:
              pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
              pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
              anc_points (Tensor): shape(num_total_anchors, 2)
              gt_labels (Tensor): shape(bs, n_max_boxes, 1)
              gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
          Returns:
              target_labels (Tensor): shape(bs, num_total_anchors)
              target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
              target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
              fg_mask (Tensor): shape(bs, num_total_anchors), bool
              target_gt_idx (Tensor): shape(bs, num_total_anchors)
          """
          self.bs = pd_scores.size(0)
          self.n_max_boxes = gt_bboxes.size(1)

          if self.n_max_boxes == 0:
              # No gt anywhere in the batch: every anchor is background. The
              # regular path below cannot index into an empty gt dimension.
              per_anchor = pd_scores[..., 0]  # (bs, num_total_anchors)
              return (torch.full_like(per_anchor, self.bg_idx, dtype=torch.long),
                      torch.zeros_like(pd_bboxes),
                      torch.zeros_like(pd_scores),
                      torch.zeros_like(per_anchor, dtype=torch.bool),
                      torch.zeros_like(per_anchor, dtype=torch.long))

          mask_pos, align_metric, overlaps = self.get_pos_mask(
              pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
          target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
              mask_pos, overlaps, self.n_max_boxes)
          # assigned target
          target_labels, target_bboxes, target_scores = self.get_targets(
              gt_labels, gt_bboxes, target_gt_idx, fg_mask)
          # Normalize: rescale each gt's metrics so that its best positive anchor
          # receives that gt's best positive IoU, then take the max over gts.
          align_metric *= mask_pos
          pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
          pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
          norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
          target_scores = target_scores * norm_align_metric
          return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

      def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
          """Build the positive-anchor mask.

          Returns:
              mask_pos (Tensor): (b, max_num_obj, h*w), 1 for anchors inside the
                  gt box that are also among its top-k alignment metrics.
              align_metric (Tensor): (b, max_num_obj, h*w)
              overlaps (Tensor): (b, max_num_obj, h*w)
          """
          # get anchor_align metric, (b, max_num_obj, h*w)
          align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
          # get in_gts mask, (b, max_num_obj, h*w)
          mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
          # top-k is taken only among the anchors inside the gt box
          mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
          # merge all mask to a final mask, (b, max_num_obj, h*w)
          mask_pos = mask_topk * mask_in_gts
          return mask_pos, align_metric, overlaps

      def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
          """Compute ``score ** alpha * iou ** beta`` for every (gt, anchor) pair.

          Returns:
              align_metric (Tensor): (b, max_num_obj, h*w)
              overlaps (Tensor): (b, max_num_obj, h*w), IoU clamped to >= 0.
          """
          device = pd_scores.device
          # Index tensors live on the data's device, so the gather below does not
          # mix CPU indices with (possibly) CUDA scores.
          batch_ind = torch.arange(end=self.bs, device=device).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
          cls_ind = gt_labels.long().squeeze(-1).to(device)  # b, max_num_obj
          # predicted score of each gt's class at every anchor
          bbox_scores = pd_scores[batch_ind, :, cls_ind]  # b, max_num_obj, h*w
          # NOTE(review): assumes bbox_iou returns (b, max_num_obj, h*w, 1) here.
          overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False).squeeze(3).clamp(0)
          align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
          return align_metric, overlaps

      def select_topk_candidates(self, metrics, largest=True):
          """Mark the top-k anchors of every gt.

          Args:
              metrics (Tensor): (b, max_num_obj, h*w).
              largest (bool): pick the largest (True) or smallest (False) metrics.
          Returns:
              (Tensor): (b, max_num_obj, h*w) in ``metrics.dtype``; 1 for the
                  top-k anchors of each gt, all 0 for a gt whose best top-k metric
                  is not above ``eps``.
          """
          num_anchors = metrics.shape[-1]  # h*w
          k = min(self.topk, num_anchors)  # torch.topk rejects k > h*w
          if k == 0:
              return torch.zeros_like(metrics)
          # (b, max_num_obj, k)
          topk_metrics, topk_idxs = torch.topk(metrics, k, dim=-1, largest=largest)
          # gts without any usable candidate contribute nothing, (b, max_num_obj, 1)
          valid = topk_metrics.max(-1, keepdim=True)[0] > self.eps
          # topk indices are distinct per row, so a plain scatter builds the mask
          # without the (b, max_num_obj, k, h*w) one-hot intermediate.
          is_in_topk = torch.zeros_like(metrics)
          is_in_topk.scatter_(-1, topk_idxs, valid.to(metrics.dtype).expand_as(topk_idxs).contiguous())
          return is_in_topk

      def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
          """Gather the label, box and one-hot score of each anchor's assigned gt.

          Args:
              gt_labels: (b, max_num_obj, 1)
              gt_bboxes: (b, max_num_obj, 4)
              target_gt_idx: (b, h*w)
              fg_mask: (b, h*w)
          Returns:
              target_labels: (b, h*w) long, negative labels clamped to 0
              target_bboxes: (b, h*w, 4)
              target_scores: (b, h*w, num_classes), one-hot for foreground
                  anchors and 0 elsewhere
          """
          # offset per-image gt indices into the flattened (b * max_num_obj) gt list
          batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
          target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
          target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
          # assigned target boxes, (b * max_num_obj, 4) -> (b, h*w, 4)
          target_bboxes = gt_bboxes.reshape(-1, 4)[target_gt_idx]
          # In place, so negative (padding) labels never reach F.one_hot, which
          # rejects them.
          target_labels.clamp_(0)
          target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, num_classes)
          # zero the scores of background anchors
          target_scores = torch.where(fg_mask[:, :, None] > 0, target_scores, 0)
          return target_labels, target_bboxes, target_scores
  168. # -------------------------- Aligned SimOTA Assigner --------------------------
  class AlignedSimOTA(object):
      """Aligned SimOTA label assigner operating on a single image.

      This code referenced to https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py

      Args:
          num_classes (int): number of object classes.
          center_sampling_radius (float): half-size, in units of the anchor's
              stride, of the square around each gt center whose anchors are
              candidates.
          topk_candidate (int): number of top IoUs summed to obtain each gt's
              dynamic k.
      """
      def __init__(self, num_classes, center_sampling_radius, topk_candidate):
          self.num_classes = num_classes
          self.center_sampling_radius = center_sampling_radius
          self.topk_candidate = topk_candidate

      @torch.no_grad()
      def __call__(self,
                   fpn_strides,
                   anchors,
                   pred_cls,
                   pred_box,
                   tgt_labels,
                   tgt_bboxes):
          """Match the anchors of one image to its gt boxes.

          Args:
              fpn_strides (sequence): stride of each pyramid level.
              anchors (List[Tensor]): per level, anchor centers of shape (M_i, 2).
              pred_cls (Tensor): shape(M, C) class logits; sigmoid is applied here.
              pred_box (Tensor): shape(M, 4) predicted boxes. NOTE(review):
                  presumably (x1, y1, x2, y2) like tgt_bboxes, as box_iou needs.
              tgt_labels (Tensor): shape(N,) gt class indices.
              tgt_bboxes (Tensor): shape(N, 4) gt boxes as (x1, y1, x2, y2).
          Returns:
              fg_mask (Tensor): shape(M,) bool, True for anchors matched to a gt.
              assigned_labels (Tensor): shape(num_fg,) class of the matched gt.
              assigned_ious (Tensor): shape(num_fg,) IoU with the matched gt.
              assigned_indexs (Tensor): shape(num_fg,) index of the matched gt.
          """
          # [M,]
          strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
                                      for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
          # List[F, M, 2] -> [M, 2]
          anchors = torch.cat(anchors, dim=0)
          num_anchor = anchors.shape[0]
          num_gt = len(tgt_labels)

          # ----------------------- Find inside points -----------------------
          fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
              tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
          cls_preds = pred_cls[fg_mask].float()  # [Mp, C]
          box_preds = pred_box[fg_mask].float()  # [Mp, 4]

          # ----------------------- Reg cost -----------------------
          pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)  # [N, Mp]
          reg_cost = -torch.log(pair_wise_ious + 1e-8)  # [N, Mp]

          # ----------------------- Cls cost -----------------------
          # computed in full precision even under mixed-precision training
          with torch.cuda.amp.autocast(enabled=False):
              # [Mp, C] -> [N, Mp, C]
              score_preds = cls_preds.sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
              # prepare cls_target
              cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
              cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
              # [N, Mp]
              cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
          del score_preds

          # ----------------------- Dynamic K-Matching -----------------------
          # pairs outside both the gt box and its center region are heavily penalised
          cost_matrix = (
              cls_cost
              + 3.0 * reg_cost
              + 100000.0 * (~is_in_boxes_and_center)
          )  # [N, Mp]
          (
              assigned_labels,   # [num_fg,]
              assigned_ious,     # [num_fg,]
              assigned_indexs,   # [num_fg,]
          ) = self.dynamic_k_matching(
              cost_matrix,
              pair_wise_ious,
              tgt_labels,
              num_gt,
              fg_mask
          )
          del cls_cost, cost_matrix, pair_wise_ious, reg_cost

          return fg_mask, assigned_labels, assigned_ious, assigned_indexs

      def get_in_boxes_info(
          self,
          gt_bboxes,    # [N, 4]
          anchors,      # [M, 2]
          strides,      # [M,]
          num_anchors,  # M
          num_gt,       # N
          ):
          """Find anchors inside gt boxes and/or inside gt center regions.

          ``num_anchors`` and ``num_gt`` are kept for interface compatibility; the
          shapes are obtained by broadcasting.

          Returns:
              is_in_boxes_anchor (Tensor): shape(M,) bool, the anchor lies inside
                  any gt box or any gt center region.
              is_in_boxes_and_center (Tensor): shape(N, Mp) bool, for those Mp
                  anchors only: inside both gt n's box and its center region.
          """
          # anchor centers, [M,] -> [1, M]; gt columns below are [N, 1] -> [N, M]
          x_centers = anchors[:, 0].unsqueeze(0)
          y_centers = anchors[:, 1].unsqueeze(0)

          # ---- inside gt boxes ----
          b_l = x_centers - gt_bboxes[:, 0:1]  # x1
          b_t = y_centers - gt_bboxes[:, 1:2]  # y1
          b_r = gt_bboxes[:, 2:3] - x_centers  # x2
          b_b = gt_bboxes[:, 3:4] - y_centers  # y2
          bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)  # [N, M, 4]
          is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
          is_in_boxes_all = is_in_boxes.sum(dim=0) > 0

          # ---- inside fixed-size center regions ----
          gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5  # [N, 2]
          center_radius_ = self.center_sampling_radius * strides.unsqueeze(0)  # [1, M]
          c_l = x_centers - (gt_centers[:, 0:1] - center_radius_)  # x1
          c_t = y_centers - (gt_centers[:, 1:2] - center_radius_)  # y1
          c_r = (gt_centers[:, 0:1] + center_radius_) - x_centers  # x2
          c_b = (gt_centers[:, 1:2] + center_radius_) - y_centers  # y2
          center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)  # [N, M, 4]
          is_in_centers = center_deltas.min(dim=-1).values > 0.0
          is_in_centers_all = is_in_centers.sum(dim=0) > 0

          # in boxes and in centers
          is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
          is_in_boxes_and_center = (
              is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
          )
          return is_in_boxes_anchor, is_in_boxes_and_center

      def dynamic_k_matching(
          self,
          cost,
          pair_wise_ious,
          gt_classes,
          num_gt,
          fg_mask
          ):
          """Give each gt its k lowest-cost anchors, k derived from its top IoUs.

          Args:
              cost (Tensor): [N, Mp] matching cost.
              pair_wise_ious (Tensor): [N, Mp] IoU between gts and candidates.
              gt_classes (Tensor): [N,] gt class indices.
              num_gt (int): N.
              fg_mask (Tensor): [M,] bool candidate mask with Mp True entries;
                  updated in place to keep only the matched anchors.
          Returns:
              assigned_labels, assigned_ious, assigned_indexs: each [num_fg,].
          """
          matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
          num_candidates = cost.size(1)  # Mp

          # k of each gt = int(sum of its top IoUs), at least 1
          n_candidate_k = min(self.topk_candidate, pair_wise_ious.size(1))
          topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
          dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1).tolist()
          for gt_idx in range(num_gt):
              # never ask for more anchors than there are candidates (Mp may be 0)
              k = min(dynamic_ks[gt_idx], num_candidates)
              _, pos_idx = torch.topk(cost[gt_idx], k=k, largest=False)
              matching_matrix[gt_idx][pos_idx] = 1

          # an anchor matched by several gts keeps only its lowest-cost gt
          multi_matched = matching_matrix.sum(0) > 1
          if multi_matched.any():
              _, cost_argmin = torch.min(cost[:, multi_matched], dim=0)
              matching_matrix[:, multi_matched] = 0
              matching_matrix[cost_argmin, multi_matched] = 1

          fg_mask_inboxes = matching_matrix.sum(0) > 0
          fg_mask[fg_mask.clone()] = fg_mask_inboxes

          fg_matching = matching_matrix[:, fg_mask_inboxes]
          if fg_matching.numel() > 0:
              assigned_indexs = fg_matching.argmax(0)
          else:
              # argmax cannot reduce an empty gt axis; nothing is assigned anyway
              assigned_indexs = torch.zeros(fg_matching.size(1), dtype=torch.long, device=cost.device)
          assigned_labels = gt_classes[assigned_indexs]
          assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
          return assigned_labels, assigned_ious, assigned_indexs