matcher.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. # https://github.com/facebookresearch/detr
  3. import torch
  4. import torch.nn as nn
  5. from scipy.optimize import linear_sum_assignment
  6. from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
  7. class HungarianMatcher(nn.Module):
  8. def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
  9. super().__init__()
  10. self.cost_class = cost_class
  11. self.cost_bbox = cost_bbox
  12. self.cost_giou = cost_giou
  13. @torch.no_grad()
  14. def forward(self, outputs, targets):
  15. bs, num_queries = outputs["pred_logits"].shape[:2]
  16. # [B * num_queries, C] = [N, C]
  17. out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)
  18. out_bbox = outputs["pred_boxes"].flatten(0, 1)
  19. # [M,] where M is number of all targets in this batch
  20. tgt_ids = torch.cat([v["labels"] for v in targets])
  21. # [M, 4]
  22. tgt_bbox = torch.cat([v["boxes"] for v in targets])
  23. # [N, M]
  24. cost_class = -out_prob[:, tgt_ids]
  25. # [N, M]
  26. cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
  27. # [N, M]
  28. cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
  29. # Final cost matrix: [N, M]
  30. C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
  31. # [N, M] -> [B, num_queries, M]
  32. C = C.view(bs, num_queries, -1).cpu()
  33. # Optimziee cost
  34. sizes = [len(v["boxes"]) for v in targets]
  35. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
  36. return [(torch.as_tensor(i, dtype=torch.int64), # tgt indexes
  37. torch.as_tensor(j, dtype=torch.int64)) # pred indexes
  38. for i, j in indices]