matcher.py

import torch
import numpy as np


class YoloMatcher(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes

    def generate_dxdywh(self, gt_box, img_size, stride):
        x1, y1, x2, y2 = gt_box
        # xyxy -> cxcywh
        xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
        bw, bh = x2 - x1, y2 - y1

        # discard degenerate boxes (less than one pixel wide or tall)
        if bw < 1. or bh < 1.:
            return False

        # compute the grid cell that contains the box center
        xs_c = xc / stride
        ys_c = yc / stride
        grid_x = int(xs_c)
        grid_y = int(ys_c)

        # regression targets: center offset within the cell, log of width/height
        tx = xs_c - grid_x
        ty = ys_c - grid_y
        tw = np.log(bw)
        th = np.log(bh)

        # loss weight for the box regression terms: larger for smaller boxes
        weight = 2.0 - (bh / img_size[0]) * (bw / img_size[1])

        return grid_x, grid_y, tx, ty, tw, th, weight
    @torch.no_grad()
    def __call__(self, img_size, stride, targets):
        """
        img_size: (Tuple[int, int]) input image size as (height, width).
        stride: (Int) stride of the YOLOv1 output.
        targets: (List[Dict]) one dict per image:
                 {'boxes': [...], 'labels': [...], 'orig_size': ...}
        """
        # prepare the dense target maps, one entry per grid cell
        bs = len(targets)
        fmp_h, fmp_w = img_size[0] // stride, img_size[1] // stride
        gt_objectness = np.zeros([bs, fmp_h, fmp_w, 1])
        gt_labels = np.zeros([bs, fmp_h, fmp_w, 1])
        gt_bboxes = np.zeros([bs, fmp_h, fmp_w, 4])
        gt_box_weight = np.zeros([bs, fmp_h, fmp_w, 1])

        for batch_index in range(bs):
            targets_per_image = targets[batch_index]
            # [N,]
            tgt_cls = targets_per_image["labels"].numpy()
            # [N, 4]
            tgt_box = targets_per_image["boxes"].numpy()

            for gt_box, gt_label in zip(tgt_box, tgt_cls):
                result = self.generate_dxdywh(gt_box, img_size, stride)
                if result:
                    grid_x, grid_y, tx, ty, tw, th, weight = result
                    if grid_x < fmp_w and grid_y < fmp_h:
                        gt_objectness[batch_index, grid_y, grid_x] = 1.0
                        gt_labels[batch_index, grid_y, grid_x] = gt_label
                        gt_bboxes[batch_index, grid_y, grid_x] = np.array([tx, ty, tw, th])
                        gt_box_weight[batch_index, grid_y, grid_x] = weight

        # flatten the spatial dims: [B, H, W, C] -> [B, M, C] with M = H * W
        gt_objectness = gt_objectness.reshape(bs, -1, 1)
        gt_labels = gt_labels.reshape(bs, -1, 1)
        gt_bboxes = gt_bboxes.reshape(bs, -1, 4)
        gt_box_weight = gt_box_weight.reshape(bs, -1, 1)

        # to tensor
        gt_objectness = torch.from_numpy(gt_objectness).float()
        gt_labels = torch.from_numpy(gt_labels).long()
        gt_bboxes = torch.from_numpy(gt_bboxes).float()
        gt_box_weight = torch.from_numpy(gt_box_weight).float()

        return gt_objectness, gt_labels, gt_bboxes, gt_box_weight
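

# Usage sketch (an illustrative addition, not part of the original matcher.py):
# builds dummy targets in the format the matcher expects and inspects the
# returned shapes. The 640x640 image size, stride of 32, and 20 classes below
# are assumed example values, not settings taken from the source.
if __name__ == '__main__':
    matcher = YoloMatcher(num_classes=20)
    # one image with two ground-truth boxes in xyxy pixel coordinates
    targets = [{
        'boxes': torch.tensor([[ 50.,  60., 200., 220.],
                               [300., 300., 400., 420.]]),
        'labels': torch.tensor([3, 7]),
        'orig_size': (640, 640),
    }]
    gt_obj, gt_cls, gt_box, gt_w = matcher(img_size=(640, 640), stride=32, targets=targets)
    # with a 640x640 input and stride 32, M = (640 // 32) ** 2 = 400 grid cells
    print(gt_obj.shape, gt_cls.shape, gt_box.shape, gt_w.shape)
    # -> torch.Size([1, 400, 1]) torch.Size([1, 400, 1]) torch.Size([1, 400, 4]) torch.Size([1, 400, 1])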