import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from .loss_utils import sigmoid_focal_loss
    from .loss_utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou, bbox2delta
    from .loss_utils import is_dist_avail_and_initialized, get_world_size
    from .matcher import HungarianMatcher
except ImportError:
    from loss_utils import sigmoid_focal_loss
    from loss_utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou, bbox2delta
    from loss_utils import is_dist_avail_and_initialized, get_world_size
    from matcher import HungarianMatcher


class Criterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute a Hungarian assignment between the ground-truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, cfg, num_classes=80, aux_loss=False):
        super().__init__()
        # ------------ Basic parameters ------------
        self.cfg = cfg
        self.num_classes = num_classes
        self.k_one2many = cfg['k_one2many']
        self.lambda_one2many = cfg['lambda_one2many']
        self.aux_loss = aux_loss
        self.losses = ['labels', 'boxes']
        # ------------- Focal loss -------------
        self.alpha = 0.25
        self.gamma = 2.0
        # ------------ Matcher ------------
        self.matcher = HungarianMatcher(cost_class = cfg['matcher_hpy']['cost_class'],
                                        cost_bbox  = cfg['matcher_hpy']['cost_bbox'],
                                        cost_giou  = cfg['matcher_hpy']['cost_giou']
                                        )
        # ------------- Loss weight -------------
        self.weight_dict = {'loss_cls':  cfg['loss_coeff']['class'],
                            'loss_box':  cfg['loss_coeff']['bbox'],
                            'loss_giou': cfg['loss_coeff']['giou']}
        # Repeat the loss weights for every intermediate decoder layer
        if aux_loss:
            aux_weight_dict = {}
            for i in range(cfg['de_num_layers'] - 1):
                aux_weight_dict.update({k + f'_{i}': v for k, v in self.weight_dict.items()})
            self.weight_dict.update(aux_weight_dict)
        # ------------- One2many loss weight -------------
        if cfg['num_queries_one2many'] > 0:
            one2many_loss_weight = {}
            for k, v in self.weight_dict.items():
                one2many_loss_weight[k + "_one2many"] = v
            self.weight_dict.update(one2many_loss_weight)
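
    # For example, with aux_loss=True and cfg['de_num_layers'] = 6, weight_dict
    # holds 'loss_cls', 'loss_box', 'loss_giou' plus '_0' ... '_4' copies for the
    # five intermediate decoder layers; if cfg['num_queries_one2many'] > 0, every
    # one of those keys is additionally duplicated with a '_one2many' suffix.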

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
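
    # For example, with batch_size = 2 and matcher output
    #   indices = [(tensor([3, 7]), tensor([1, 0])), (tensor([2]), tensor([0]))]
    # _get_src_permutation_idx returns
    #   (tensor([0, 0, 1]), tensor([3, 7, 2]))
    # i.e. (batch index, query index) pairs that can index `pred_logits` / `pred_boxes`.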

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Classification loss (sigmoid focal loss)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        # prepare class targets: unmatched queries get the background index (= num_classes)
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]).to(src_logits.device)
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o
        # to one-hot labels, then drop the background column
        target_classes_onehot = torch.zeros([*src_logits.shape[:2], self.num_classes + 1],
                                            dtype=src_logits.dtype,
                                            layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        target_classes_onehot = target_classes_onehot[..., :-1]
        # focal loss
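        # sigmoid_focal_loss is assumed here to follow the standard definition
        # from Lin et al. (2017): FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t),
        # with alpha = 0.25 and gamma = 2.0 as set in __init__.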
        loss_cls = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, self.alpha, self.gamma)
        losses = {}
        losses['loss_cls'] = loss_cls * src_logits.shape[1]  # rescale by the number of queries
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        # prepare bbox targets
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(src_boxes.device)
        # compute the L1 loss in delta space: the regression targets are the
        # offsets from the previous-stage boxes to the ground-truth boxes
        src_deltas = outputs["pred_deltas"][idx]
        src_boxes_old = outputs["pred_boxes_old"][idx]
        target_deltas = bbox2delta(src_boxes_old, target_boxes)
        loss_bbox = F.l1_loss(src_deltas, target_deltas, reduction="none")
        # compute GIoU loss on the decoded boxes
        bbox_giou = generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
                                        box_cxcywh_to_xyxy(target_boxes))
        loss_giou = 1 - torch.diag(bbox_giou)
        losses = {}
        losses['loss_box'] = loss_bbox.sum() / num_boxes
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'boxes':  self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def compute_loss(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {
            k: v
            for k, v in outputs.items()
            if k != "aux_outputs" and k != "enc_outputs"
        }
        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)
        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
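        # e.g. with two distributed workers holding 3 and 5 target boxes, all_reduce
        # sums num_boxes to 8 and every rank normalizes its losses by 8 / 2 = 4.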
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(
                self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
            )
        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    kwargs = {}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_boxes, **kwargs
                    )
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)
        # Encoder proposals are supervised class-agnostically: all labels are
        # zeroed, so every object belongs to a single foreground class.
        if "enc_outputs" in outputs:
            enc_outputs = outputs["enc_outputs"]
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt["labels"] = torch.zeros_like(bt["labels"])
            indices = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                kwargs = {}
                l_dict = self.get_loss(
                    loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
                )
                l_dict = {k + "_enc": v for k, v in l_dict.items()}
                losses.update(l_dict)
        return losses

    def forward(self, outputs, targets):
        # --------------------- One-to-one losses ---------------------
        outputs_one2one = {k: v for k, v in outputs.items() if "one2many" not in k}
        loss_dict = self.compute_loss(outputs_one2one, targets)
        # --------------------- One-to-many losses ---------------------
        # strip the "_one2many" suffix (9 characters) from the keys
        outputs_one2many = {k[:-9]: v for k, v in outputs.items() if "one2many" in k}
        if len(outputs_one2many) > 0:
            # Copy targets and repeat each ground-truth k_one2many times
            multi_targets = copy.deepcopy(targets)
            for target in multi_targets:
                target["boxes"] = target["boxes"].repeat(self.k_one2many, 1)
                target["labels"] = target["labels"].repeat(self.k_one2many)
            # Compute one-to-many losses
            one2many_loss_dict = self.compute_loss(outputs_one2many, multi_targets)
            # add the one2many losses into the final loss_dict
            for k, v in one2many_loss_dict.items():
                if k + "_one2many" in loss_dict.keys():
                    loss_dict[k + "_one2many"] += v * self.lambda_one2many
                else:
                    loss_dict[k + "_one2many"] = v * self.lambda_one2many
        return loss_dict
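
    # For example, with k_one2many = 6 every ground-truth box is duplicated six
    # times for the auxiliary one-to-many branch, so up to six queries can be
    # matched to the same object; the resulting "*_one2many" losses are scaled
    # by lambda_one2many before being merged into the final loss dict.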


# build criterion
def build_criterion(cfg, num_classes, aux_loss=True):
    criterion = Criterion(cfg, num_classes, aux_loss)
    return criterion
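

if __name__ == '__main__':
    # Minimal smoke-test sketch. Everything below is illustrative: the cfg
    # values are made-up assumptions (not the repo's defaults), and the random
    # tensors only mimic the output keys this criterion reads ('pred_logits',
    # 'pred_boxes', 'pred_deltas', 'pred_boxes_old'). In a real trainer the
    # per-term weights in criterion.weight_dict would be applied to loss_dict.
    cfg = {
        'k_one2many': 6,
        'lambda_one2many': 1.0,
        'num_queries_one2many': 0,   # disable the one-to-many branch here
        'de_num_layers': 6,
        'matcher_hpy': {'cost_class': 2.0, 'cost_bbox': 5.0, 'cost_giou': 2.0},
        'loss_coeff': {'class': 1.0, 'bbox': 5.0, 'giou': 2.0},
    }
    criterion = build_criterion(cfg, num_classes=80, aux_loss=False)
    bs, num_queries = 2, 100
    outputs = {
        'pred_logits':    torch.randn(bs, num_queries, 80),
        'pred_boxes':     torch.rand(bs, num_queries, 4),   # normalized cxcywh
        'pred_deltas':    torch.randn(bs, num_queries, 4),
        'pred_boxes_old': torch.rand(bs, num_queries, 4),
    }
    targets = [
        {'labels': torch.tensor([1, 2]), 'boxes': torch.rand(2, 4)},
        {'labels': torch.tensor([3]),    'boxes': torch.rand(1, 4)},
    ]
    loss_dict = criterion(outputs, targets)
    print({k: round(v.item(), 4) for k, v in loss_dict.items()})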