criterion.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. """
  2. reference:
  3. https://github.com/facebookresearch/detr/blob/main/models/detr.py
  4. by lyuwenyu
  5. """
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
  10. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  11. from .matcher import HungarianMatcher
  12. # --------------- Criterion for DETR ---------------
  13. class SetCriterion(nn.Module):
  14. def __init__(self, cfg):
  15. super().__init__()
  16. self.num_classes = cfg.num_classes
  17. self.losses = ['labels', 'boxes']
  18. self.eos_coef = 0.1
  19. # -------- Loss weights --------
  20. self.weight_dict = {'loss_cls': cfg.loss_cls,
  21. 'loss_box': cfg.loss_box,
  22. 'loss_giou': cfg.loss_giou}
  23. for i in range(cfg.num_dec_layers - 1):
  24. self.weight_dict.update({k + f'_aux_{i}': v for k, v in self.weight_dict.items()})
  25. empty_weight = torch.ones(self.num_classes + 1)
  26. empty_weight[-1] = self.eos_coef
  27. self.register_buffer('empty_weight', empty_weight)
  28. # -------- Matcher --------
  29. matcher_hpy = cfg.matcher_hpy
  30. self.matcher = HungarianMatcher(matcher_hpy['cost_class'], matcher_hpy['cost_bbox'], matcher_hpy['cost_giou'])
  31. def loss_labels(self, outputs, targets, indices, num_boxes):
  32. assert 'pred_logits' in outputs
  33. src_logits = outputs['pred_logits']
  34. target_classes = torch.full(src_logits.shape[:2], self.num_classes,
  35. dtype=torch.int64, device=src_logits.device)
  36. loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
  37. return {'loss_cls': loss_cls.sum() / num_boxes}
  38. def loss_boxes(self, outputs, targets, indices, num_boxes):
  39. assert 'pred_boxes' in outputs
  40. idx = self._get_src_permutation_idx(indices)
  41. src_boxes = outputs['pred_boxes'][idx]
  42. target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
  43. loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
  44. loss_giou = 1 - torch.diag(generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
  45. box_cxcywh_to_xyxy(target_boxes)))
  46. return {'loss_box': loss_bbox.sum() / num_boxes,
  47. 'loss_giou': loss_giou.sum() / num_boxes}
  48. def _get_src_permutation_idx(self, indices):
  49. # permute predictions following indices
  50. batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  51. src_idx = torch.cat([src for (src, _) in indices])
  52. return batch_idx, src_idx
  53. def _get_tgt_permutation_idx(self, indices):
  54. # permute targets following indices
  55. batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  56. tgt_idx = torch.cat([tgt for (_, tgt) in indices])
  57. return batch_idx, tgt_idx
  58. def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
  59. loss_map = {
  60. 'boxes': self.loss_boxes,
  61. 'labels': self.loss_labels,
  62. }
  63. return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
  64. def forward(self, outputs, targets):
  65. outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
  66. # Retrieve the matching between the outputs of the last layer and the targets
  67. indices = self.matcher(outputs_without_aux, targets)
  68. # Compute the average number of target boxes accross all nodes, for normalization purposes
  69. num_boxes = sum(len(t["labels"]) for t in targets)
  70. num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
  71. if is_dist_avail_and_initialized():
  72. torch.distributed.all_reduce(num_boxes)
  73. num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
  74. # Compute all the requested losses
  75. loss_dict = {}
  76. for loss in self.losses:
  77. l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
  78. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  79. loss_dict.update(l_dict)
  80. # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  81. if 'aux_outputs' in outputs:
  82. for i, aux_outputs in enumerate(outputs['aux_outputs']):
  83. indices = self.matcher(aux_outputs, targets)
  84. for loss in self.losses:
  85. l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
  86. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  87. l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
  88. loss_dict.update(l_dict)
  89. # Total loss
  90. loss_dict["losses"] = sum(loss_dict.values())
  91. return loss_dict