matcher.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # ------------------------------------------------------------------------
  2. # Plain-DETR
  3. # Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia.
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # ------------------------------------------------------------------------
  6. # Deformable DETR
  7. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  8. # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
  9. # ------------------------------------------------------------------------
  10. # Modified from DETR (https://github.com/facebookresearch/detr)
  11. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  12. # ------------------------------------------------------------------------
  13. """
  14. Modules to compute the matching cost and solve the corresponding LSAP.
  15. """
  16. import torch
  17. from scipy.optimize import linear_sum_assignment
  18. from torch import nn
  19. from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, bbox2delta
  20. class HungarianMatcher(nn.Module):
  21. def __init__(self,
  22. cost_class: float = 1,
  23. cost_bbox: float = 1,
  24. cost_giou: float = 1,
  25. ):
  26. super().__init__()
  27. self.cost_class = cost_class
  28. self.cost_bbox = cost_bbox
  29. self.cost_giou = cost_giou
  30. assert (
  31. cost_class != 0 or cost_bbox != 0 or cost_giou != 0
  32. ), "all costs cant be 0"
  33. def forward(self, outputs, targets):
  34. """ Performs the matching
  35. Params:
  36. outputs: This is a dict that contains at least these entries:
  37. "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
  38. "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
  39. targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
  40. "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
  41. objects in the target) containing the class labels
  42. "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
  43. Returns:
  44. A list of size batch_size, containing tuples of (index_i, index_j) where:
  45. - index_i is the indices of the selected predictions (in order)
  46. - index_j is the indices of the corresponding selected targets (in order)
  47. For each batch element, it holds:
  48. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  49. """
  50. with torch.no_grad():
  51. bs, num_queries = outputs["pred_logits"].shape[:2]
  52. # We flatten to compute the cost matrices in a batch
  53. out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
  54. out_bbox = outputs["pred_boxes"].flatten(0, 1)
  55. # Also concat the target labels and boxes
  56. tgt_ids = torch.cat([v["labels"] for v in targets]).to(out_prob.device)
  57. tgt_bbox = torch.cat([v["boxes"] for v in targets]).to(out_prob.device)
  58. # Compute the classification cost.
  59. alpha = 0.25
  60. gamma = 2.0
  61. neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
  62. pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
  63. cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
  64. # Compute the L1 cost between boxes
  65. out_delta = outputs["pred_deltas"].flatten(0, 1)
  66. out_bbox_old = outputs["pred_boxes_old"].flatten(0, 1)
  67. tgt_delta = bbox2delta(out_bbox_old, tgt_bbox)
  68. cost_bbox = torch.cdist(out_delta[:, None], tgt_delta, p=1).squeeze(1)
  69. # Compute the giou cost betwen boxes
  70. cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
  71. box_cxcywh_to_xyxy(tgt_bbox)
  72. )
  73. # Final cost matrix
  74. C = self.cost_bbox * cost_bbox + \
  75. self.cost_class * cost_class + \
  76. self.cost_giou * cost_giou
  77. C = C.view(bs, num_queries, -1).cpu()
  78. sizes = [len(v["boxes"]) for v in targets]
  79. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
  80. return [(torch.as_tensor(i, dtype=torch.int64), # batch index
  81. torch.as_tensor(j, dtype=torch.int64)) # query index
  82. for i, j in indices]