loss.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import copy
  5. from .matcher import build_matcher
  6. from utils.misc import sigmoid_focal_loss
  7. from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
  8. from utils.distributed_utils import is_dist_avail_and_initialized, get_world_size
  class Criterion(nn.Module):
      """Compute the set-prediction loss for a DETR-style detector.

      The computation happens in two steps:
        1) ``matcher`` computes a one-to-one assignment between ground-truth
           boxes and model predictions;
        2) every matched ground-truth / prediction pair is supervised
           (classification with a sigmoid focal loss, boxes with L1 + GIoU).
      """

      def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
          """Create the criterion.

          Parameters:
              num_classes: number of object categories, omitting the special
                  no-object category; also used as the no-object class index.
              matcher: callable ``matcher(outputs, targets)`` returning a list
                  with one ``(src_idx, tgt_idx)`` pair of index tensors per image.
              weight_dict: maps loss names (e.g. 'loss_cls', 'loss_bbox_0') to
                  their relative weight. Losses whose name is absent are still
                  returned by ``forward`` but are not added to the total.
              losses: list of loss groups to compute; see ``get_loss`` for the
                  supported names.
              focal_alpha: ``alpha`` passed to the sigmoid focal loss.
          """
          super().__init__()
          self.num_classes = num_classes
          self.matcher = matcher
          self.weight_dict = weight_dict
          self.losses = losses
          self.focal_alpha = focal_alpha

      def _get_src_permutation_idx(self, indices):
          """Flatten matched prediction indices into ``(batch_idx, src_idx)``.

          ``batch_idx[k]`` is the image index of the k-th matched prediction
          and ``src_idx[k]`` its index along dim 1, so ``x[batch_idx, src_idx]``
          gathers the matched predictions of a batched tensor ``x``.
          """
          batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
          src_idx = torch.cat([src for (src, _) in indices])
          return batch_idx, src_idx

      def _get_tgt_permutation_idx(self, indices):
          """Flatten matched target indices into ``(batch_idx, tgt_idx)``."""
          batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
          tgt_idx = torch.cat([tgt for (_, tgt) in indices])
          return batch_idx, tgt_idx

      def loss_labels(self, outputs, targets, indices, num_boxes):
          """Classification loss (sigmoid focal loss).

          Each target dict must contain the key "labels" holding a tensor of
          dim [nb_target_boxes]. Unmatched predictions receive class index
          ``num_classes``. Assuming the logits have ``num_classes`` channels,
          that index lands in an extra one-hot column which is sliced off,
          leaving those predictions with an all-zero target vector.

          Returns:
              ``{'loss_cls': scalar tensor}``.
          """
          assert 'pred_logits' in outputs
          src_logits = outputs['pred_logits']
          idx = self._get_src_permutation_idx(indices)
          target_classes_o = torch.cat(
              [t["labels"][J] for t, (_, J) in zip(targets, indices)]).to(src_logits.device)
          target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                      dtype=torch.int64, device=src_logits.device)
          target_classes[idx] = target_classes_o
          # One extra column absorbs the no-object index; it is dropped below.
          target_classes_onehot = torch.zeros(
              [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
              dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
          target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
          target_classes_onehot = target_classes_onehot[:, :, :-1]
          # NOTE(review): scaled by the size of dim 1 (predictions per image),
          # presumably to undo a per-prediction mean inside sigmoid_focal_loss
          # — confirm against utils.misc.
          loss_cls = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes,
                                        alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
          return {'loss_cls': loss_cls}

      def loss_boxes(self, outputs, targets, indices, num_boxes):
          """Box losses for matched pairs: L1 regression and GIoU.

          Each target dict must contain the key "boxes" holding a tensor of
          dim [nb_target_boxes, 4]. Boxes are expected as
          (center_x, center_y, w, h), normalized by the image size.

          Returns:
              ``{'loss_bbox': ..., 'loss_giou': ...}``, each summed over the
              matched pairs and divided by ``num_boxes``.
          """
          assert 'pred_boxes' in outputs
          idx = self._get_src_permutation_idx(indices)
          src_boxes = outputs['pred_boxes'][idx]
          target_boxes = torch.cat(
              [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(src_boxes.device)
          loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
          losses = {}
          losses['loss_bbox'] = loss_bbox.sum() / num_boxes
          # Only the diagonal (pred k vs. its own target k) of the pairwise GIoU matrix is used.
          loss_giou = 1 - torch.diag(generalized_box_iou(
              box_cxcywh_to_xyxy(src_boxes),
              box_cxcywh_to_xyxy(target_boxes)))
          losses['loss_giou'] = loss_giou.sum() / num_boxes
          return losses

      def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
          """Dispatch to the loss function registered under ``loss``.

          Supported names: 'labels' (``loss_labels``) and 'boxes'
          (``loss_boxes``). An unknown name fails the assertion.
          """
          loss_map = {
              'labels': self.loss_labels,
              'boxes': self.loss_boxes,
          }
          assert loss in loss_map, f'do you really want to compute {loss} loss?'
          return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

      def forward(self, outputs, targets):
          """Compute every requested loss and their weighted total.

          Parameters:
              outputs: dict of model outputs, holding whatever the matcher and
                  the selected losses read ('pred_logits', 'pred_boxes'). An
                  optional 'aux_outputs' entry is a list of dicts of the same
                  form, one per auxiliary prediction head.
              targets: list of dicts, such that len(targets) == batch_size.
                  The expected keys in each dict depend on the losses
                  applied; see each loss' docstring.

          Returns:
              dict mapping loss names to values. Auxiliary entries get an
              ``_{i}`` suffix. The key 'losses' holds the sum of every entry
              whose name is in ``weight_dict``, times its weight.
          """
          outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
          # Retrieve the matching between the outputs of the last layer and the targets.
          indices = self.matcher(outputs_without_aux, targets)

          # Average number of target boxes across all nodes, used for normalization.
          # The first entry of ``outputs`` need not be a tensor (e.g. 'aux_outputs'),
          # so take the device from the first tensor-valued entry instead.
          device = next((v.device for v in outputs.values() if isinstance(v, torch.Tensor)), None)
          num_boxes = sum(len(t["labels"]) for t in targets)
          num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
          if is_dist_avail_and_initialized():
              torch.distributed.all_reduce(num_boxes)
          num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

          # Compute all the requested losses.
          losses = {}
          for loss in self.losses:
              losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

          # Auxiliary losses: re-match and repeat for each intermediate output.
          if 'aux_outputs' in outputs:
              for i, aux_outputs in enumerate(outputs['aux_outputs']):
                  aux_indices = self.matcher(aux_outputs, targets)
                  for loss in self.losses:
                      l_dict = self.get_loss(loss, aux_outputs, targets, aux_indices, num_boxes)
                      losses.update({k + f'_{i}': v for k, v in l_dict.items()})

          weight_dict = self.weight_dict
          losses['losses'] = sum(losses[k] * weight_dict[k] for k in losses if k in weight_dict)
          return losses
  # Build criterion.
  def build_criterion(cfg, num_classes, aux_loss=False):
      """Assemble a ``Criterion`` (matcher, loss weights, loss list) from ``cfg``.

      Parameters:
          cfg: mapping that must provide 'loss_cls_weight', 'loss_box_weight',
              'loss_giou_weight' and 'focal_alpha', plus 'num_decoder_layers'
              when ``aux_loss`` is true, and whatever ``build_matcher`` reads.
          num_classes: number of object categories, forwarded to ``Criterion``.
          aux_loss: if true, also register weights named ``<loss>_<i>`` for
              i in ``range(num_decoder_layers - 1)``, reusing the base weights.

      Returns:
          a ``Criterion`` computing the 'labels' and 'boxes' losses.
      """
      matcher = build_matcher(cfg)
      base_weights = {
          'loss_cls': cfg['loss_cls_weight'],
          'loss_bbox': cfg['loss_box_weight'],
          'loss_giou': cfg['loss_giou_weight'],
      }
      weight_dict = dict(base_weights)
      if aux_loss:
          # TODO this is a hack: presumably one weight set per intermediate
          # decoder output, matching the ``_{i}`` suffix applied in
          # Criterion.forward — confirm against the model's aux_outputs.
          weight_dict.update({
              f'{name}_{layer}': weight
              for layer in range(cfg['num_decoder_layers'] - 1)
              for name, weight in base_weights.items()
          })
      return Criterion(
          num_classes=num_classes,
          matcher=matcher,
          weight_dict=weight_dict,
          losses=['labels', 'boxes'],
          focal_alpha=cfg['focal_alpha'],
      )