loss.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. """
  2. reference:
  3. https://github.com/facebookresearch/detr/blob/main/models/detr.py
  4. by lyuwenyu
  5. """
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. try:
  10. from .loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
  11. from .loss_utils import is_dist_avail_and_initialized, get_world_size
  12. from .matcher import HungarianMatcher
  13. except:
  14. from loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
  15. from loss_utils import is_dist_avail_and_initialized, get_world_size
  16. from matcher import HungarianMatcher
  17. # --------------- Criterion for RT-DETR ---------------
  18. def build_criterion(cfg, num_classes=80):
  19. matcher = HungarianMatcher(cfg['matcher_hpy'], alpha=0.25, gamma=2.0)
  20. weight_dict = {'loss_cls': cfg['loss_coeff']['class'],
  21. 'loss_box': cfg['loss_coeff']['bbox'],
  22. 'loss_giou': cfg['loss_coeff']['giou']}
  23. criterion = Criterion(matcher, weight_dict, num_classes=num_classes)
  24. return criterion
  25. class Criterion(nn.Module):
  26. """ This class computes the loss for DETR.
  27. The process happens in two steps:
  28. 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
  29. 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
  30. """
  31. def __init__(self, matcher, weight_dict, num_classes=80):
  32. """ Create the criterion.
  33. Parameters:
  34. num_classes: number of object categories, omitting the special no-object category
  35. matcher: module able to compute a matching between targets and proposals
  36. weight_dict: dict containing as key the names of the losses and as values their relative weight.
  37. eos_coef: relative classification weight applied to the no-object category
  38. losses: list of all the losses to be applied. See get_loss for list of available losses.
  39. """
  40. super().__init__()
  41. self.num_classes = num_classes
  42. self.matcher = matcher
  43. self.weight_dict = weight_dict
  44. self.losses = ['labels', 'boxes']
  45. self.alpha = 0.75 # For VFL
  46. self.gamma = 2.0
  47. def loss_labels(self, outputs, targets, indices, num_boxes):
  48. "Compute variable focal loss"
  49. assert 'pred_boxes' in outputs
  50. idx = self._get_src_permutation_idx(indices)
  51. # Compute IoU between pred and target
  52. src_boxes = outputs['pred_boxes'][idx]
  53. target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
  54. ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
  55. ious = torch.diag(ious).detach()
  56. # One-hot class label
  57. src_logits = outputs['pred_logits']
  58. target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
  59. target_classes = torch.full(src_logits.shape[:2], self.num_classes,
  60. dtype=torch.int64, device=src_logits.device)
  61. target_classes[idx] = target_classes_o
  62. target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
  63. # Iou-aware class label
  64. target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
  65. target_score_o[idx] = ious.to(target_score_o.dtype)
  66. target_score = target_score_o.unsqueeze(-1) * target
  67. # Compute VFL
  68. pred_score = F.sigmoid(src_logits).detach()
  69. weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
  70. loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
  71. loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
  72. return {'loss_cls': loss}
  73. def loss_boxes(self, outputs, targets, indices, num_boxes):
  74. """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
  75. targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
  76. The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
  77. """
  78. assert 'pred_boxes' in outputs
  79. idx = self._get_src_permutation_idx(indices)
  80. src_boxes = outputs['pred_boxes'][idx]
  81. target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
  82. losses = {}
  83. loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
  84. losses['loss_box'] = loss_bbox.sum() / num_boxes
  85. loss_giou = 1 - torch.diag(generalized_box_iou(
  86. box_cxcywh_to_xyxy(src_boxes),
  87. box_cxcywh_to_xyxy(target_boxes)))
  88. losses['loss_giou'] = loss_giou.sum() / num_boxes
  89. return losses
  90. def _get_src_permutation_idx(self, indices):
  91. # permute predictions following indices
  92. batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  93. src_idx = torch.cat([src for (src, _) in indices])
  94. return batch_idx, src_idx
  95. def _get_tgt_permutation_idx(self, indices):
  96. # permute targets following indices
  97. batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  98. tgt_idx = torch.cat([tgt for (_, tgt) in indices])
  99. return batch_idx, tgt_idx
  100. def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
  101. loss_map = {
  102. 'boxes': self.loss_boxes,
  103. 'labels': self.loss_labels,
  104. }
  105. assert loss in loss_map, f'do you really want to compute {loss} loss?'
  106. return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
  107. def forward(self, outputs, targets):
  108. """ This performs the loss computation.
  109. Parameters:
  110. outputs: dict of tensors, see the output specification of the model for the format
  111. targets: list of dicts, such that len(targets) == batch_size.
  112. The expected keys in each dict depends on the losses applied, see each loss' doc
  113. """
  114. outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
  115. # Retrieve the matching between the outputs of the last layer and the targets
  116. indices = self.matcher(outputs_without_aux, targets)
  117. # Compute the average number of target boxes accross all nodes, for normalization purposes
  118. num_boxes = sum(len(t["labels"]) for t in targets)
  119. num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
  120. if is_dist_avail_and_initialized():
  121. torch.distributed.all_reduce(num_boxes)
  122. num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
  123. # Compute all the requested losses
  124. losses = {}
  125. for loss in self.losses:
  126. l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
  127. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  128. losses.update(l_dict)
  129. # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  130. if 'aux_outputs' in outputs:
  131. for i, aux_outputs in enumerate(outputs['aux_outputs']):
  132. indices = self.matcher(aux_outputs, targets)
  133. for loss in self.losses:
  134. l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
  135. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  136. l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
  137. losses.update(l_dict)
  138. # In case of cdn auxiliary losses. For rtdetr
  139. if 'dn_aux_outputs' in outputs:
  140. assert 'dn_meta' in outputs, ''
  141. indices = self.get_cdn_matched_indices(outputs['dn_meta'], targets)
  142. num_boxes = num_boxes * outputs['dn_meta']['dn_num_group']
  143. for i, aux_outputs in enumerate(outputs['dn_aux_outputs']):
  144. # indices = self.matcher(aux_outputs, targets)
  145. for loss in self.losses:
  146. l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
  147. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  148. l_dict = {k + f'_dn_{i}': v for k, v in l_dict.items()}
  149. losses.update(l_dict)
  150. return losses
  151. @staticmethod
  152. def get_cdn_matched_indices(dn_meta, targets):
  153. '''get_cdn_matched_indices
  154. '''
  155. dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
  156. num_gts = [len(t['labels']) for t in targets]
  157. device = targets[0]['labels'].device
  158. dn_match_indices = []
  159. for i, num_gt in enumerate(num_gts):
  160. if num_gt > 0:
  161. gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
  162. gt_idx = gt_idx.tile(dn_num_group)
  163. assert len(dn_positive_idx[i]) == len(gt_idx)
  164. dn_match_indices.append((dn_positive_idx[i], gt_idx))
  165. else:
  166. dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
  167. torch.zeros(0, dtype=torch.int64, device=device)))
  168. return dn_match_indices