matcher.py

import numpy as np
import torch
from torch import nn

from utils.box_ops import *


class UniformMatcher(nn.Module):
    """
    Uniform matching: for each ground-truth box, take the top-k (match_times)
    closest predicted boxes and the top-k closest anchor boxes as positives.

    This code is adapted from
    https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/uniform_matcher.py
    """

    def __init__(self, match_times: int = 4):
        super().__init__()
        self.match_times = match_times

    @torch.no_grad()
    def forward(self, pred_boxes, anchor_boxes, targets):
        """
        pred_boxes:   (Tensor) -> [B, num_queries, 4]
        anchor_boxes: (Tensor) -> [num_queries, 4]
        targets:      (List[Dict]) -> each dict holds 'boxes': [...] and 'labels': [...]
        """
        bs, num_queries = pred_boxes.shape[:2]

        # We flatten to compute the cost matrices in a batch
        # [B, num_queries, 4] -> [M, 4], where M = B * num_queries
        out_bbox = pred_boxes.flatten(0, 1)
        # [num_queries, 4] -> [1, num_queries, 4] -> [B, num_queries, 4] -> [M, 4]
        anchor_boxes = anchor_boxes[None].repeat(bs, 1, 1)
        anchor_boxes = anchor_boxes.flatten(0, 1)

        # Also concat the target boxes
        tgt_bbox = torch.cat([v['boxes'] for v in targets])

        # Compute the L1 cost between boxes.
        # Note that we use both the predicted boxes and the anchor boxes.
        cost_bbox = torch.cdist(box_xyxy_to_cxcywh(out_bbox),
                                box_xyxy_to_cxcywh(tgt_bbox),
                                p=1)
        cost_bbox_anchors = torch.cdist(anchor_boxes,
                                        box_xyxy_to_cxcywh(tgt_bbox),
                                        p=1)

        # Final cost matrices: [B, M, N], M = num_queries, N = num_tgt
        C = cost_bbox
        C = C.view(bs, num_queries, -1).cpu()
        C1 = cost_bbox_anchors
        C1 = C1.view(bs, num_queries, -1).cpu()
        # The number of object instances in each image
        sizes = [len(v['boxes']) for v in targets]
        all_indices_list = [[] for _ in range(bs)]

        # Positive indices when matching predicted boxes and gt boxes:
        # len(indices) = batch size, len(tuple) = topk (match_times)
        indices = [
            tuple(
                torch.topk(
                    c[i],
                    k=self.match_times,
                    dim=0,
                    largest=False)[1].numpy().tolist()
            )
            for i, c in enumerate(C.split(sizes, -1))
        ]
        # Positive indices when matching anchor boxes and gt boxes
        indices1 = [
            tuple(
                torch.topk(
                    c[i],
                    k=self.match_times,
                    dim=0,
                    largest=False)[1].numpy().tolist()
            )
            for i, c in enumerate(C1.split(sizes, -1))
        ]
        # Concat the indices according to image ids (img_id = batch_id)
        for img_id, (idx, idx1) in enumerate(zip(indices, indices1)):
            img_idx_i = [
                np.array(idx_ + idx1_)
                for (idx_, idx1_) in zip(idx, idx1)
            ]  # 'i' is the index of the queries
            img_idx_j = [
                np.array(list(range(len(idx_))) + list(range(len(idx1_))))
                for (idx_, idx1_) in zip(idx, idx1)
            ]  # 'j' is the index of the targets
            all_indices_list[img_id] = [*zip(img_idx_i, img_idx_j)]
        # Re-organize the positive indices per image
        all_indices = []
        for img_id in range(bs):
            all_idx_i = []
            all_idx_j = []
            for idx_list in all_indices_list[img_id]:
                idx_i, idx_j = idx_list
                all_idx_i.append(idx_i)
                all_idx_j.append(idx_j)
            all_idx_i = np.hstack(all_idx_i)
            all_idx_j = np.hstack(all_idx_j)
            all_indices.append((all_idx_i, all_idx_j))

        return [(torch.as_tensor(i, dtype=torch.int64),
                 torch.as_tensor(j, dtype=torch.int64)) for i, j in all_indices]
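

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): run the matcher on
    # random tensors purely to illustrate the expected input and output shapes.
    # The batch size, query count, and per-image target counts below are
    # arbitrary assumptions.
    B, num_queries = 2, 100
    pred_boxes = torch.rand(B, num_queries, 4)   # [B, num_queries, 4]
    anchor_boxes = torch.rand(num_queries, 4)    # [num_queries, 4]
    targets = [
        {'boxes': torch.rand(3, 4), 'labels': torch.zeros(3, dtype=torch.int64)},
        {'boxes': torch.rand(5, 4), 'labels': torch.zeros(5, dtype=torch.int64)},
    ]

    matcher = UniformMatcher(match_times=4)
    indices = matcher(pred_boxes, anchor_boxes, targets)
    # One (query_idx, tgt_idx) pair of index tensors per image; each gt box gets
    # match_times positives from the predicted boxes plus match_times from the
    # anchor boxes, i.e. 2 * match_times * num_gt pairs per image.
    for img_id, (idx_i, idx_j) in enumerate(indices):
        print(f"image {img_id}: {idx_i.shape[0]} positive pairs")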