matcher.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # Modified by BaseDetection, Inc. and its affiliates.
  4. import torch
  5. from utils.box_ops import box_iou
  6. class RetinaNetMatcher(object):
  7. """
  8. This class assigns to each predicted "element" (e.g., a box) a ground-truth
  9. element. Each predicted element will have exactly zero or one matches; each
  10. ground-truth element may be matched to zero or more predicted elements.
  11. The matching is determined by the MxN match_quality_matrix, that characterizes
  12. how well each (ground-truth, prediction)-pair match each other. For example,
  13. if the elements are boxes, this matrix may contain box intersection-over-union
  14. overlap values.
  15. The matcher returns (a) a vector of length N containing the index of the
  16. ground-truth element m in [0, M) that matches to prediction n in [0, N).
  17. (b) a vector of length N containing the labels for each prediction.
  18. """
  19. def __init__(self,
  20. num_classes,
  21. iou_threshold,
  22. iou_labels,
  23. allow_low_quality_matches=False):
  24. """
  25. Args:
  26. thresholds (list): a list of thresholds used to stratify predictions
  27. into levels.
  28. labels (list): a list of values to label predictions belonging at
  29. each level. A label can be one of {-1, 0, 1} signifying
  30. {ignore, negative class, positive class}, respectively.
  31. allow_low_quality_matches (bool): if True, produce additional matches
  32. for predictions with maximum match quality lower than high_threshold.
  33. See set_low_quality_matches_ for more details.
  34. For example,
  35. thresholds = [0.3, 0.5]
  36. labels = [0, -1, 1]
  37. All predictions with iou < 0.3 will be marked with 0 and
  38. thus will be considered as false positives while training.
  39. All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
  40. thus will be ignored.
  41. All predictions with 0.5 <= iou will be marked with 1 and
  42. thus will be considered as true positives.
  43. """
  44. self.num_classes = num_classes
  45. # Add -inf and +inf to first and last position in iou_thresholdhreshold
  46. iou_threshold = iou_threshold[:]
  47. assert iou_threshold[0] > 0
  48. iou_threshold.insert(0, -float("inf"))
  49. iou_threshold.append(float("inf"))
  50. assert all(low <= high for (low, high) in zip(iou_threshold[:-1], iou_threshold[1:]))
  51. assert all(label in [-1, 0, 1] for label in iou_labels)
  52. assert len(iou_labels) == len(iou_threshold) - 1
  53. self.iou_threshold = iou_threshold
  54. self.iou_labels = iou_labels
  55. self.allow_low_quality_matches = allow_low_quality_matches
  56. @torch.no_grad()
  57. def __call__(self, anchors, targets):
  58. """
  59. anchors: (Tensor) [B, M, 4] (x1, y1, x2, y2)
  60. targets: (Dict) dict{'boxes': [...],
  61. 'labels': [...],
  62. 'orig_size': ...}
  63. """
  64. # list[Tensor(R, 4)], one for each image
  65. gt_classes = []
  66. gt_boxes = []
  67. device = anchors.device
  68. for anchors_per_image, targets_per_image in zip(anchors, targets):
  69. # [N,]
  70. tgt_labels = targets_per_image['labels'].to(device)
  71. # [N, 4]
  72. tgt_boxes = targets_per_image['boxes'].to(device)
  73. # [N, M], N is the number of targets, M is the number of anchors
  74. match_quality_matrix, _ = box_iou(tgt_boxes, anchors_per_image)
  75. gt_matched_idxs, anchor_labels = self.matching(match_quality_matrix)
  76. has_gt = len(tgt_labels) > 0
  77. if has_gt:
  78. # ground truth box regression
  79. matched_gt_boxes = tgt_boxes[gt_matched_idxs]
  80. gt_classes_i = tgt_labels[gt_matched_idxs]
  81. # Anchors with label 0 are treated as background.
  82. gt_classes_i[anchor_labels == 0] = self.num_classes
  83. # Anchors with label -1 are ignored.
  84. gt_classes_i[anchor_labels == -1] = -1
  85. else:
  86. gt_classes_i = torch.zeros_like(gt_matched_idxs) + self.num_classes
  87. matched_gt_boxes = torch.zeros_like(anchors_per_image)
  88. gt_classes.append(gt_classes_i)
  89. gt_boxes.append(matched_gt_boxes)
  90. return torch.stack(gt_classes), torch.stack(gt_boxes)
  91. def matching(self, match_quality_matrix):
  92. """
  93. Args:
  94. match_quality_matrix (Tensor[float]): an N x M tensor, containing the
  95. pairwise quality between N ground-truth elements and M predicted
  96. elements. All elements must be >= 0 (due to the us of `torch.nonzero`
  97. for selecting indices in :meth:`set_low_quality_matches_`).
  98. Returns:
  99. matches (Tensor[int64]): a vector of length M, where matches[i] is a matched
  100. ground-truth index in [0, N)
  101. match_labels (Tensor[int8]): a vector of length M, where pred_labels[i] indicates
  102. whether a prediction is a true or false positive or ignored
  103. """
  104. assert match_quality_matrix.dim() == 2
  105. if match_quality_matrix.numel() == 0:
  106. default_matches = match_quality_matrix.new_full(
  107. (match_quality_matrix.size(1),), 0, dtype=torch.int64
  108. )
  109. # When no gt boxes exist, we define IOU = 0 and therefore set labels
  110. # to `self.labels[0]`, which usually defaults to background class 0
  111. # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
  112. default_match_labels = match_quality_matrix.new_full(
  113. (match_quality_matrix.size(1),), self.iou_labels[0], dtype=torch.int8
  114. )
  115. return default_matches, default_match_labels
  116. assert torch.all(match_quality_matrix >= 0)
  117. # match_quality_matrix is N (gt) x M (predicted)
  118. # Max over gt elements (dim 0) to find best gt candidate for each prediction
  119. matched_vals, matches = match_quality_matrix.max(dim=0)
  120. match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
  121. for (l, low, high) in zip(self.iou_labels, self.iou_threshold[:-1], self.iou_threshold[1:]):
  122. low_high = (matched_vals >= low) & (matched_vals < high)
  123. match_labels[low_high] = l
  124. if self.allow_low_quality_matches:
  125. self.set_low_quality_matches_(match_labels, match_quality_matrix)
  126. return matches, match_labels
  127. def set_low_quality_matches_(self, match_labels, match_quality_matrix):
  128. """
  129. Produce additional matches for predictions that have only low-quality matches.
  130. Specifically, for each ground-truth G find the set of predictions that have
  131. maximum overlap with it (including ties); for each prediction in that set, if
  132. it is unmatched, then match it to the ground-truth G.
  133. This function implements the RPN assignment case (i) in Sec. 3.1.2 of the
  134. Faster R-CNN paper: https://arxiv.org/pdf/1506.01497v3.pdf.
  135. """
  136. # For each gt, find the prediction with which it has highest quality
  137. highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
  138. # Find the highest quality match available, even if it is low, including ties.
  139. # Note that the matches qualities must be positive due to the use of
  140. # `torch.nonzero`.
  141. gt_pred_pairs_of_highest_quality = torch.nonzero(
  142. match_quality_matrix == highest_quality_foreach_gt[:, None],
  143. as_tuple=False
  144. )
  145. # Example gt_pred_pairs_of_highest_quality:
  146. # tensor([[ 0, 39796],
  147. # [ 1, 32055],
  148. # [ 1, 32070],
  149. # [ 2, 39190],
  150. # [ 2, 40255],
  151. # [ 3, 40390],
  152. # [ 3, 41455],
  153. # [ 4, 45470],
  154. # [ 5, 45325],
  155. # [ 5, 46390]])
  156. # Each row is a (gt index, prediction index)
  157. # Note how gt items 1, 2, 3, and 5 each have two ties
  158. pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
  159. match_labels[pred_inds_to_update] = 1