matcher.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # ------------------------------------------------------------------------
  2. # Plain-DETR
  3. # Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia.
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # ------------------------------------------------------------------------
  6. # Deformable DETR
  7. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  8. # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
  9. # ------------------------------------------------------------------------
  10. # Modified from DETR (https://github.com/facebookresearch/detr)
  11. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  12. # ------------------------------------------------------------------------
  13. """
  14. Modules to compute the matching cost and solve the corresponding LSAP.
  15. """
  16. import torch
  17. from scipy.optimize import linear_sum_assignment
  18. from torch import nn
  19. try:
  20. from .loss_utils import box_cxcywh_to_xyxy, generalized_box_iou, bbox2delta
  21. except:
  22. from loss_utils import box_cxcywh_to_xyxy, generalized_box_iou, bbox2delta
  23. class HungarianMatcher(nn.Module):
  24. """This class computes an assignment between the targets and the predictions of the network
  25. For efficiency reasons, the targets don't include the no_object. Because of this, in general,
  26. there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
  27. while the others are un-matched (and thus treated as non-objects).
  28. """
  29. def __init__(self,
  30. cost_class: float = 1,
  31. cost_bbox: float = 1,
  32. cost_giou: float = 1,
  33. ):
  34. """Creates the matcher
  35. Params:
  36. cost_class: This is the relative weight of the classification error in the matching cost
  37. cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
  38. cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
  39. """
  40. super().__init__()
  41. self.cost_class = cost_class
  42. self.cost_bbox = cost_bbox
  43. self.cost_giou = cost_giou
  44. assert (
  45. cost_class != 0 or cost_bbox != 0 or cost_giou != 0
  46. ), "all costs cant be 0"
  47. def forward(self, outputs, targets):
  48. """ Performs the matching
  49. Params:
  50. outputs: This is a dict that contains at least these entries:
  51. "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
  52. "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
  53. targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
  54. "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
  55. objects in the target) containing the class labels
  56. "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
  57. Returns:
  58. A list of size batch_size, containing tuples of (index_i, index_j) where:
  59. - index_i is the indices of the selected predictions (in order)
  60. - index_j is the indices of the corresponding selected targets (in order)
  61. For each batch element, it holds:
  62. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  63. """
  64. with torch.no_grad():
  65. bs, num_queries = outputs["pred_logits"].shape[:2]
  66. # We flatten to compute the cost matrices in a batch
  67. out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
  68. out_bbox = outputs["pred_boxes"].flatten(0, 1)
  69. # Also concat the target labels and boxes
  70. tgt_ids = torch.cat([v["labels"] for v in targets]).to(out_prob.device)
  71. tgt_bbox = torch.cat([v["boxes"] for v in targets]).to(out_prob.device)
  72. # Compute the classification cost.
  73. alpha = 0.25
  74. gamma = 2.0
  75. neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
  76. pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
  77. cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
  78. # Compute the L1 cost between boxes
  79. out_delta = outputs["pred_deltas"].flatten(0, 1)
  80. out_bbox_old = outputs["pred_boxes_old"].flatten(0, 1)
  81. tgt_delta = bbox2delta(out_bbox_old, tgt_bbox)
  82. cost_bbox = torch.cdist(out_delta[:, None], tgt_delta, p=1).squeeze(1)
  83. # Compute the giou cost betwen boxes
  84. cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
  85. box_cxcywh_to_xyxy(tgt_bbox)
  86. )
  87. # Final cost matrix
  88. C = self.cost_bbox * cost_bbox + \
  89. self.cost_class * cost_class + \
  90. self.cost_giou * cost_giou
  91. C = C.view(bs, num_queries, -1).cpu()
  92. sizes = [len(v["boxes"]) for v in targets]
  93. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
  94. return [(torch.as_tensor(i, dtype=torch.int64), # batch index
  95. torch.as_tensor(j, dtype=torch.int64)) # query index
  96. for i, j in indices]