matcher.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. import math
  2. import torch
  3. import torch.nn.functional as F
  4. from utils.box_ops import *
  5. @torch.no_grad()
  6. def get_ious_and_iou_loss(inputs,
  7. targets,
  8. weight=None,
  9. box_mode="xyxy",
  10. loss_type="iou",
  11. reduction="none"):
  12. """
  13. Compute iou loss of type ['iou', 'giou', 'linear_iou']
  14. Args:
  15. inputs (tensor): pred values
  16. targets (tensor): target values
  17. weight (tensor): loss weight
  18. box_mode (str): 'xyxy' or 'ltrb', 'ltrb' is currently supported.
  19. loss_type (str): 'giou' or 'iou' or 'linear_iou'
  20. reduction (str): reduction manner
  21. Returns:
  22. loss (tensor): computed iou loss.
  23. """
  24. if box_mode == "ltrb":
  25. inputs = torch.cat((-inputs[..., :2], inputs[..., 2:]), dim=-1)
  26. targets = torch.cat((-targets[..., :2], targets[..., 2:]), dim=-1)
  27. elif box_mode != "xyxy":
  28. raise NotImplementedError
  29. eps = torch.finfo(torch.float32).eps
  30. inputs_area = (inputs[..., 2] - inputs[..., 0]).clamp_(min=0) \
  31. * (inputs[..., 3] - inputs[..., 1]).clamp_(min=0)
  32. targets_area = (targets[..., 2] - targets[..., 0]).clamp_(min=0) \
  33. * (targets[..., 3] - targets[..., 1]).clamp_(min=0)
  34. w_intersect = (torch.min(inputs[..., 2], targets[..., 2])
  35. - torch.max(inputs[..., 0], targets[..., 0])).clamp_(min=0)
  36. h_intersect = (torch.min(inputs[..., 3], targets[..., 3])
  37. - torch.max(inputs[..., 1], targets[..., 1])).clamp_(min=0)
  38. area_intersect = w_intersect * h_intersect
  39. area_union = targets_area + inputs_area - area_intersect
  40. ious = area_intersect / area_union.clamp(min=eps)
  41. if loss_type == "iou":
  42. loss = -ious.clamp(min=eps).log()
  43. elif loss_type == "linear_iou":
  44. loss = 1 - ious
  45. elif loss_type == "giou":
  46. g_w_intersect = torch.max(inputs[..., 2], targets[..., 2]) \
  47. - torch.min(inputs[..., 0], targets[..., 0])
  48. g_h_intersect = torch.max(inputs[..., 3], targets[..., 3]) \
  49. - torch.min(inputs[..., 1], targets[..., 1])
  50. ac_uion = g_w_intersect * g_h_intersect
  51. gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps)
  52. loss = 1 - gious
  53. else:
  54. raise NotImplementedError
  55. if weight is not None:
  56. loss = loss * weight.view(loss.size())
  57. if reduction == "mean":
  58. loss = loss.sum() / max(weight.sum().item(), eps)
  59. else:
  60. if reduction == "mean":
  61. loss = loss.mean()
  62. if reduction == "sum":
  63. loss = loss.sum()
  64. return ious, loss
  65. class FcosMatcher(object):
  66. """
  67. This code referenced to https://github.com/Megvii-BaseDetection/cvpods
  68. """
  69. def __init__(self,
  70. num_classes,
  71. center_sampling_radius,
  72. object_sizes_of_interest,
  73. box_weights=[1, 1, 1, 1]):
  74. self.num_classes = num_classes
  75. self.center_sampling_radius = center_sampling_radius
  76. self.object_sizes_of_interest = object_sizes_of_interest
  77. self.box_weightss = box_weights
  78. def get_deltas(self, anchors, boxes):
  79. """
  80. Get box regression transformation deltas (dl, dt, dr, db) that can be used
  81. to transform the `anchors` into the `boxes`. That is, the relation
  82. ``boxes == self.apply_deltas(deltas, anchors)`` is true.
  83. Args:
  84. anchors (Tensor): anchors, e.g., feature map coordinates
  85. boxes (Tensor): target of the transformation, e.g., ground-truth
  86. boxes.
  87. """
  88. assert isinstance(anchors, torch.Tensor), type(anchors)
  89. assert isinstance(boxes, torch.Tensor), type(boxes)
  90. deltas = torch.cat((anchors - boxes[..., :2], boxes[..., 2:] - anchors),
  91. dim=-1) * anchors.new_tensor(self.box_weightss)
  92. return deltas
  93. @torch.no_grad()
  94. def __call__(self, fpn_strides, anchors, targets):
  95. """
  96. fpn_strides: (List) List[8, 16, 32, ...] stride of network output.
  97. anchors: (List of Tensor) List[F, M, 2], F = num_fpn_levels
  98. targets: (Dict) dict{'boxes': [...],
  99. 'labels': [...],
  100. 'orig_size': ...}
  101. """
  102. gt_classes = []
  103. gt_anchors_deltas = []
  104. gt_centerness = []
  105. device = anchors[0].device
  106. # List[F, M, 2] -> [M, 2]
  107. anchors_over_all_feature_maps = torch.cat(anchors, dim=0).to(device)
  108. for targets_per_image in targets:
  109. # generate object_sizes_of_interest: List[[M, 2]]
  110. object_sizes_of_interest = [anchors_i.new_tensor(scale_range).unsqueeze(0).expand(anchors_i.size(0), -1)
  111. for anchors_i, scale_range in zip(anchors, self.object_sizes_of_interest)]
  112. # List[F, M, 2] -> [M, 2], M = M1 + M2 + ... + MF
  113. object_sizes_of_interest = torch.cat(object_sizes_of_interest, dim=0)
  114. # [N, 4]
  115. tgt_box = targets_per_image['boxes'].to(device)
  116. # [N, C]
  117. tgt_cls = targets_per_image['labels'].to(device)
  118. # [N, M, 4], M = M1 + M2 + ... + MF
  119. deltas = self.get_deltas(anchors_over_all_feature_maps, tgt_box.unsqueeze(1))
  120. has_gt = (len(tgt_cls) > 0)
  121. if has_gt:
  122. if self.center_sampling_radius > 0:
  123. # bbox centers: [N, 2]
  124. centers = (tgt_box[..., :2] + tgt_box[..., 2:]) * 0.5
  125. is_in_boxes = []
  126. for stride, anchors_i in zip(fpn_strides, anchors):
  127. radius = stride * self.center_sampling_radius
  128. # [N, 4]
  129. center_boxes = torch.cat((
  130. torch.max(centers - radius, tgt_box[:, :2]),
  131. torch.min(centers + radius, tgt_box[:, 2:]),
  132. ), dim=-1)
  133. # [N, Mi, 4]
  134. center_deltas = self.get_deltas(anchors_i, center_boxes.unsqueeze(1))
  135. # [N, Mi]
  136. is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
  137. # [N, M], M = M1 + M2 + ... + MF
  138. is_in_boxes = torch.cat(is_in_boxes, dim=1)
  139. else:
  140. # no center sampling, it will use all the locations within a ground-truth box
  141. # [N, M], M = M1 + M2 + ... + MF
  142. is_in_boxes = deltas.min(dim=-1).values > 0
  143. # [N, M], M = M1 + M2 + ... + MF
  144. max_deltas = deltas.max(dim=-1).values
  145. # limit the regression range for each location
  146. is_cared_in_the_level = \
  147. (max_deltas >= object_sizes_of_interest[None, :, 0]) & \
  148. (max_deltas <= object_sizes_of_interest[None, :, 1])
  149. # [N,]
  150. tgt_box_area = (tgt_box[:, 2] - tgt_box[:, 0]) * (tgt_box[:, 3] - tgt_box[:, 1])
  151. # [N,] -> [N, 1] -> [N, M]
  152. gt_positions_area = tgt_box_area.unsqueeze(1).repeat(
  153. 1, anchors_over_all_feature_maps.size(0))
  154. gt_positions_area[~is_in_boxes] = math.inf
  155. gt_positions_area[~is_cared_in_the_level] = math.inf
  156. # if there are still more than one objects for a position,
  157. # we choose the one with minimal area
  158. # [M,], each element is the index of ground-truth
  159. positions_min_area, gt_matched_idxs = gt_positions_area.min(dim=0)
  160. # ground truth box regression
  161. # [M, 4]
  162. gt_anchors_reg_deltas_i = self.get_deltas(
  163. anchors_over_all_feature_maps, tgt_box[gt_matched_idxs])
  164. # [M,]
  165. tgt_cls_i = tgt_cls[gt_matched_idxs]
  166. # anchors with area inf are treated as background.
  167. tgt_cls_i[positions_min_area == math.inf] = self.num_classes
  168. # ground truth centerness
  169. left_right = gt_anchors_reg_deltas_i[:, [0, 2]]
  170. top_bottom = gt_anchors_reg_deltas_i[:, [1, 3]]
  171. # [M,]
  172. gt_centerness_i = torch.sqrt(
  173. (left_right.min(dim=-1).values / left_right.max(dim=-1).values).clamp_(min=0)
  174. * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values).clamp_(min=0)
  175. )
  176. gt_classes.append(tgt_cls_i)
  177. gt_anchors_deltas.append(gt_anchors_reg_deltas_i)
  178. gt_centerness.append(gt_centerness_i)
  179. del centers, center_boxes, deltas, max_deltas, center_deltas
  180. else:
  181. tgt_cls_i = torch.zeros(anchors_over_all_feature_maps.shape[0], device=device) + self.num_classes
  182. gt_anchors_reg_deltas_i = torch.zeros([anchors_over_all_feature_maps.shape[0], 4], device=device)
  183. gt_centerness_i = torch.zeros(anchors_over_all_feature_maps.shape[0], device=device)
  184. gt_classes.append(tgt_cls_i.long())
  185. gt_anchors_deltas.append(gt_anchors_reg_deltas_i.float())
  186. gt_centerness.append(gt_centerness_i.float())
  187. # [B, M], [B, M, 4], [B, M]
  188. return torch.stack(gt_classes), torch.stack(gt_anchors_deltas), torch.stack(gt_centerness)
  189. class AlignedOTAMatcher(object):
  190. """
  191. This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
  192. """
  193. def __init__(self, num_classes, soft_center_radius=3.0, topk_candidates=13):
  194. self.num_classes = num_classes
  195. self.soft_center_radius = soft_center_radius
  196. self.topk_candidates = topk_candidates
  197. @torch.no_grad()
  198. def __call__(self,
  199. fpn_strides,
  200. anchors,
  201. pred_cls,
  202. pred_box,
  203. gt_labels,
  204. gt_bboxes):
  205. # [M,]
  206. strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
  207. for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
  208. # List[F, M, 2] -> [M, 2]
  209. num_gt = len(gt_labels)
  210. anchors = torch.cat(anchors, dim=0)
  211. # check gt
  212. if num_gt == 0 or gt_bboxes.max().item() == 0.:
  213. return {
  214. 'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
  215. self.num_classes,
  216. dtype=torch.long),
  217. 'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
  218. 'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
  219. }
  220. # get inside points: [N, M]
  221. is_in_gt = self.find_inside_points(gt_bboxes, anchors)
  222. valid_mask = is_in_gt.sum(dim=0) > 0 # [M,]
  223. # ----------------------------------- soft center prior -----------------------------------
  224. gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
  225. distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
  226. ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0) # [N, M]
  227. distance = distance * valid_mask.unsqueeze(0)
  228. soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
  229. # ----------------------------------- regression cost -----------------------------------
  230. pair_wise_ious, _ = box_iou(gt_bboxes, pred_box) # [N, M]
  231. pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0
  232. # ----------------------------------- classification cost -----------------------------------
  233. ## select the predicted scores corresponded to the gt_labels
  234. pairwise_pred_scores = pred_cls.permute(1, 0) # [M, C] -> [C, M]
  235. pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float() # [N, M]
  236. ## scale factor
  237. scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
  238. ## cls cost
  239. pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
  240. pairwise_pred_scores, pair_wise_ious,
  241. reduction="none") * scale_factor # [N, M]
  242. del pairwise_pred_scores
  243. ## foreground cost matrix
  244. cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
  245. max_pad_value = torch.ones_like(cost_matrix) * 1e9
  246. cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1), # [N, M]
  247. cost_matrix, max_pad_value)
  248. # ----------------------------------- dynamic label assignment -----------------------------------
  249. matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
  250. cost_matrix, pair_wise_ious, num_gt)
  251. del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
  252. # -----------------------------------process assigned labels -----------------------------------
  253. assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
  254. self.num_classes) # [M,]
  255. assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
  256. assigned_labels = assigned_labels.long() # [M,]
  257. assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0) # [M, 4]
  258. assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds] # [M, 4]
  259. assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M,]
  260. assign_metrics[fg_mask_inboxes] = matched_pred_ious # [M,]
  261. assigned_dict = dict(
  262. assigned_labels=assigned_labels,
  263. assigned_bboxes=assigned_bboxes,
  264. assign_metrics=assign_metrics
  265. )
  266. return assigned_dict
  267. def find_inside_points(self, gt_bboxes, anchors):
  268. """
  269. gt_bboxes: Tensor -> [N, 2]
  270. anchors: Tensor -> [M, 2]
  271. """
  272. num_anchors = anchors.shape[0]
  273. num_gt = gt_bboxes.shape[0]
  274. anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1) # [N, M, 2]
  275. gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1) # [N, M, 4]
  276. # offset
  277. lt = anchors_expand - gt_bboxes_expand[..., :2]
  278. rb = gt_bboxes_expand[..., 2:] - anchors_expand
  279. bbox_deltas = torch.cat([lt, rb], dim=-1)
  280. is_in_gts = bbox_deltas.min(dim=-1).values > 0
  281. return is_in_gts
  282. def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
  283. """Use IoU and matching cost to calculate the dynamic top-k positive
  284. targets.
  285. Args:
  286. cost_matrix (Tensor): Cost matrix.
  287. pairwise_ious (Tensor): Pairwise iou matrix.
  288. num_gt (int): Number of gt.
  289. valid_mask (Tensor): Mask for valid bboxes.
  290. Returns:
  291. tuple: matched ious and gt indexes.
  292. """
  293. matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
  294. # select candidate topk ious for dynamic-k calculation
  295. candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
  296. topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
  297. # calculate dynamic k for each gt
  298. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
  299. # sorting the batch cost matirx is faster than topk
  300. _, sorted_indices = torch.sort(cost_matrix, dim=1)
  301. for gt_idx in range(num_gt):
  302. topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
  303. matching_matrix[gt_idx, :][topk_ids] = 1
  304. del topk_ious, dynamic_ks, topk_ids
  305. prior_match_gt_mask = matching_matrix.sum(0) > 1
  306. if prior_match_gt_mask.sum() > 0:
  307. cost_min, cost_argmin = torch.min(
  308. cost_matrix[:, prior_match_gt_mask], dim=0)
  309. matching_matrix[:, prior_match_gt_mask] *= 0
  310. matching_matrix[cost_argmin, prior_match_gt_mask] = 1
  311. # get foreground mask inside box and center prior
  312. fg_mask_inboxes = matching_matrix.sum(0) > 0
  313. matched_pred_ious = (matching_matrix *
  314. pairwise_ious).sum(0)[fg_mask_inboxes]
  315. matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
  316. return matched_pred_ious, matched_gt_inds, fg_mask_inboxes