matcher.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import numpy as np
  3. class Yolov2Matcher(object):
  4. def __init__(self, num_classes, anchor_size):
  5. self.num_classes = num_classes
  6. self.anchor_size = anchor_size
  7. @torch.no_grad()
  8. def __call__(self, fmp_size, stride, targets):
  9. """
  10. img_size: (Int) input image size
  11. stride: (Int) -> stride of YOLOv1 output.
  12. targets: (Dict) dict{'boxes': [...],
  13. 'labels': [...],
  14. 'orig_size': ...}
  15. """
  16. # prepare
  17. bs = len(targets)
  18. fmp_h, fmp_w = fmp_size
  19. gt_objectness = np.zeros([bs, fmp_h, fmp_w, 1])
  20. gt_classes = np.zeros([bs, fmp_h, fmp_w, self.num_classes])
  21. gt_bboxes = np.zeros([bs, fmp_h, fmp_w, 4])
  22. for batch_index in range(bs):
  23. targets_per_image = targets[batch_index]
  24. # [N,]
  25. tgt_cls = targets_per_image["labels"].numpy()
  26. # [N, 4]
  27. tgt_box = targets_per_image['boxes'].numpy()
  28. for gt_box, gt_label in zip(tgt_box, tgt_cls):
  29. x1, y1, x2, y2 = gt_box
  30. # xyxy -> cxcywh
  31. xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
  32. bw, bh = x2 - x1, y2 - y1
  33. # check
  34. if bw < 1. or bh < 1.:
  35. return False
  36. # grid
  37. xs_c = xc / stride
  38. ys_c = yc / stride
  39. grid_x = int(xs_c)
  40. grid_y = int(ys_c)
  41. if grid_x < fmp_w and grid_y < fmp_h:
  42. gt_objectness[batch_index, grid_y, grid_x] = 1.0
  43. gt_classes[batch_index, grid_y, grid_x, int(gt_label)] = 1.0
  44. gt_bboxes[batch_index, grid_y, grid_x] = np.array([x1, y1, x2, y2])
  45. # [B, M, C]
  46. gt_objectness = gt_objectness.reshape(bs, -1, 1)
  47. gt_classes = gt_classes.reshape(bs, -1, self.num_classes)
  48. gt_bboxes = gt_bboxes.reshape(bs, -1, 4)
  49. # to tensor
  50. gt_objectness = torch.from_numpy(gt_objectness).float()
  51. gt_classes = torch.from_numpy(gt_classes).float()
  52. gt_bboxes = torch.from_numpy(gt_bboxes).float()
  53. return gt_objectness, gt_classes, gt_bboxes