matcher.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # ---------------------------------------------------------------------
  2. # Copyright (c) Megvii Inc. All rights reserved.
  3. # ---------------------------------------------------------------------
  4. import numpy as np
  5. import torch
  6. from torch import nn
  7. from utils.box_ops import *
  8. class UniformMatcher(nn.Module):
  9. """
  10. This code referenced to https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/uniform_matcher.py
  11. Uniform Matching between the anchors and gt boxes, which can achieve
  12. balance in positive anchors.
  13. Args:
  14. match_times(int): Number of positive anchors for each gt box.
  15. """
  16. def __init__(self, match_times: int = 4):
  17. super().__init__()
  18. self.match_times = match_times
  19. @torch.no_grad()
  20. def forward(self, pred_boxes, anchor_boxes, targets):
  21. """
  22. pred_boxes: (Tensor) -> [B, num_queries, 4]
  23. anchor_boxes: (Tensor) -> [num_queries, 4]
  24. targets: (Dict) -> dict{'boxes': [...], 'labels': [...]}
  25. """
  26. bs, num_queries = pred_boxes.shape[:2]
  27. # We flatten to compute the cost matrices in a batch
  28. # [B, num_queries, 4] -> [M, 4]
  29. out_bbox = pred_boxes.flatten(0, 1)
  30. # [num_queries, 4] -> [1, num_queries, 4] -> [B, num_queries, 4] -> [M, 4]
  31. anchor_boxes = anchor_boxes[None].repeat(bs, 1, 1)
  32. anchor_boxes = anchor_boxes.flatten(0, 1)
  33. # Also concat the target boxes
  34. tgt_bbox = torch.cat([v['boxes'] for v in targets])
  35. # Compute the L1 cost between boxes
  36. # Note that we use anchors and predict boxes both
  37. cost_bbox = torch.cdist(box_xyxy_to_cxcywh(out_bbox),
  38. box_xyxy_to_cxcywh(tgt_bbox),
  39. p=1)
  40. cost_bbox_anchors = torch.cdist(anchor_boxes,
  41. box_xyxy_to_cxcywh(tgt_bbox),
  42. p=1)
  43. # Final cost matrix: [B, M, N], M=num_queries, N=num_tgt
  44. C = cost_bbox
  45. C = C.view(bs, num_queries, -1).cpu()
  46. C1 = cost_bbox_anchors
  47. C1 = C1.view(bs, num_queries, -1).cpu()
  48. sizes = [len(v['boxes']) for v in targets] # the number of object instances in each image
  49. all_indices_list = [[] for _ in range(bs)]
  50. # positive indices when matching predict boxes and gt boxes
  51. # len(indices) = batch size
  52. # len(tupe) = topk
  53. indices = [
  54. tuple(
  55. torch.topk(
  56. c[i],
  57. k=self.match_times,
  58. dim=0,
  59. largest=False)[1].numpy().tolist()
  60. )
  61. for i, c in enumerate(C.split(sizes, -1))
  62. ]
  63. # positive indices when matching anchor boxes and gt boxes
  64. indices1 = [
  65. tuple(
  66. torch.topk(
  67. c[i],
  68. k=self.match_times,
  69. dim=0,
  70. largest=False)[1].numpy().tolist())
  71. for i, c in enumerate(C1.split(sizes, -1))]
  72. # concat the indices according to image ids
  73. # img_id = batch_id
  74. for img_id, (idx, idx1) in enumerate(zip(indices, indices1)):
  75. img_idx_i = [
  76. np.array(idx_ + idx1_)
  77. for (idx_, idx1_) in zip(idx, idx1)
  78. ] # 'i' is the index of queris
  79. img_idx_j = [
  80. np.array(list(range(len(idx_))) + list(range(len(idx1_))))
  81. for (idx_, idx1_) in zip(idx, idx1)
  82. ] # 'j' is the index of tgt
  83. all_indices_list[img_id] = [*zip(img_idx_i, img_idx_j)]
  84. # re-organize the positive indices
  85. all_indices = []
  86. for img_id in range(bs):
  87. all_idx_i = []
  88. all_idx_j = []
  89. for idx_list in all_indices_list[img_id]:
  90. idx_i, idx_j = idx_list
  91. all_idx_i.append(idx_i)
  92. all_idx_j.append(idx_j)
  93. all_idx_i = np.hstack(all_idx_i)
  94. all_idx_j = np.hstack(all_idx_j)
  95. all_indices.append((all_idx_i, all_idx_j))
  96. return [(torch.as_tensor(i, dtype=torch.int64),
  97. torch.as_tensor(j, dtype=torch.int64)) for i, j in all_indices]