matcher.py

import torch
import numpy as np


class Yolov2Matcher(object):
    def __init__(self, iou_thresh, num_classes, anchor_size):
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh
        # anchor box
        self.num_anchors = len(anchor_size)
        self.anchor_size = anchor_size
        self.anchor_boxes = np.array(
            [[0., 0., anchor[0], anchor[1]]
             for anchor in anchor_size]
        )  # [KA, 4]
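        # e.g. anchor_size=[[32, 32], [64, 64]] gives
        # anchor_boxes = [[0., 0., 32., 32.],
        #                 [0., 0., 64., 64.]]
        # i.e. each anchor is a (bw, bh) shape centered at the origin.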
    def compute_iou(self, anchor_boxes, gt_box):
        """
        anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
        gt_box : ndarray -> [1, 4] (cx, cy, bw, bh).
        """
        # anchors: [KA, 4], cxcywh -> xyxy
        anchors = np.zeros_like(anchor_boxes)
        anchors[..., :2] = anchor_boxes[..., :2] - anchor_boxes[..., 2:] * 0.5  # x1y1
        anchors[..., 2:] = anchor_boxes[..., :2] + anchor_boxes[..., 2:] * 0.5  # x2y2
        anchors_area = anchor_boxes[..., 2] * anchor_boxes[..., 3]

        # gt_box: [1, 4] -> [KA, 4], cxcywh -> xyxy
        gt_box = np.array(gt_box).reshape(-1, 4)
        gt_box = np.repeat(gt_box, anchors.shape[0], axis=0)
        gt_box_ = np.zeros_like(gt_box)
        gt_box_[..., :2] = gt_box[..., :2] - gt_box[..., 2:] * 0.5  # x1y1
        gt_box_[..., 2:] = gt_box[..., :2] + gt_box[..., 2:] * 0.5  # x2y2
        gt_box_area = gt_box[..., 2] * gt_box[..., 3]  # bw * bh

        # intersection (clamped at 0 so disjoint boxes get zero overlap)
        inter_w = np.maximum(
            np.minimum(anchors[:, 2], gt_box_[:, 2]) - np.maximum(anchors[:, 0], gt_box_[:, 0]), 0.)
        inter_h = np.maximum(
            np.minimum(anchors[:, 3], gt_box_[:, 3]) - np.maximum(anchors[:, 1], gt_box_[:, 1]), 0.)
        inter_area = inter_w * inter_h

        # union
        union_area = anchors_area + gt_box_area - inter_area

        # iou
        iou = inter_area / union_area
        iou = np.clip(iou, a_min=1e-10, a_max=1.0)

        return iou
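    # Worked example (an illustrative sketch, not part of the original file):
    # anchors of shape (2, 2) and (2, 4), gt box of shape (2, 4), all centered
    # at the origin:
    #   anchor 0: inter = 2*2 = 4, union = 4 + 8 - 4 = 8  -> IoU = 0.5
    #   anchor 1: inter = 2*4 = 8, union = 8 + 8 - 8 = 8  -> IoU = 1.0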
    @torch.no_grad()
    def __call__(self, fmp_size, stride, targets):
        """
        fmp_size: (Tuple[int, int]) height and width of the feature map.
        stride: (Int) stride of the YOLOv2 output.
        targets: (List[Dict]) each dict is {'boxes': [...],
                                            'labels': [...],
                                            'orig_size': ...}
        """
        # prepare
        bs = len(targets)
        fmp_h, fmp_w = fmp_size
        gt_objectness = np.zeros([bs, fmp_h, fmp_w, self.num_anchors, 1])
        gt_classes = np.zeros([bs, fmp_h, fmp_w, self.num_anchors, self.num_classes])
        gt_bboxes = np.zeros([bs, fmp_h, fmp_w, self.num_anchors, 4])
        for batch_index in range(bs):
            targets_per_image = targets[batch_index]
            # [N,]
            tgt_cls = targets_per_image["labels"].numpy()
            # [N, 4]
            tgt_box = targets_per_image['boxes'].numpy()

            for gt_box, gt_label in zip(tgt_box, tgt_cls):
                x1, y1, x2, y2 = gt_box
                # xyxy -> cxcywh
                xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
                bw, bh = x2 - x1, y2 - y1
                # center the box at the origin: only its shape matters for anchor matching
                gt_box = [0, 0, bw, bh]

                # check: skip degenerate boxes
                if bw < 1. or bh < 1.:
                    continue

                # compute IoU between the gt box and every anchor
                iou = self.compute_iou(self.anchor_boxes, gt_box)
                iou_mask = (iou > self.iou_thresh)

                # compute the grid cell that contains the box center
                grid_x = int(xc / stride)
                grid_y = int(yc / stride)

                label_assignment_results = []
                if iou_mask.sum() == 0:
                    # no anchor passes the threshold:
                    # fall back to the anchor with the highest IoU score
                    anchor_idx = np.argmax(iou)
                    label_assignment_results.append([grid_x, grid_y, anchor_idx])
                else:
                    # assign every anchor whose IoU exceeds the threshold
                    for anchor_idx, iou_m in enumerate(iou_mask):
                        if iou_m:
                            label_assignment_results.append([grid_x, grid_y, anchor_idx])
                # label assignment
                for grid_x, grid_y, anchor_idx in label_assignment_results:
                    if grid_x < fmp_w and grid_y < fmp_h:
                        # obj
                        gt_objectness[batch_index, grid_y, grid_x, anchor_idx] = 1.0
                        # cls: one-hot class vector
                        cls_one_hot = np.zeros(self.num_classes)
                        cls_one_hot[int(gt_label)] = 1.0
                        gt_classes[batch_index, grid_y, grid_x, anchor_idx] = cls_one_hot
                        # box: keep the original xyxy coordinates
                        gt_bboxes[batch_index, grid_y, grid_x, anchor_idx] = np.array([x1, y1, x2, y2])
        # [B, H, W, A, C] -> [B, HWA, C]
        gt_objectness = gt_objectness.reshape(bs, -1, 1)
        gt_classes = gt_classes.reshape(bs, -1, self.num_classes)
        gt_bboxes = gt_bboxes.reshape(bs, -1, 4)

        # to tensor
        gt_objectness = torch.from_numpy(gt_objectness).float()
        gt_classes = torch.from_numpy(gt_classes).float()
        gt_bboxes = torch.from_numpy(gt_bboxes).float()

        return gt_objectness, gt_classes, gt_bboxes
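

if __name__ == "__main__":
    # Minimal usage sketch: the anchor sizes, threshold, and target below are
    # illustrative assumptions, not values from the original file.
    matcher = Yolov2Matcher(
        iou_thresh=0.5,
        num_classes=20,
        anchor_size=[[32, 32], [64, 64], [128, 128], [256, 256], [416, 416]],
    )
    # one image with a single box in xyxy image coordinates
    targets = [{
        'boxes': torch.tensor([[100., 120., 200., 260.]]),
        'labels': torch.tensor([7]),
        'orig_size': (416, 416),
    }]
    gt_objectness, gt_classes, gt_bboxes = matcher(
        fmp_size=(13, 13), stride=32, targets=targets)
    # 13 * 13 cells * 5 anchors = 845 locations
    print(gt_objectness.shape)  # torch.Size([1, 845, 1])
    print(gt_classes.shape)     # torch.Size([1, 845, 20])
    print(gt_bboxes.shape)      # torch.Size([1, 845, 4])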