loss.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from typing import Any
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from utils.box_ops import get_ious
  6. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  7. from .matcher import AlignedSimOTA
  8. class Criterion(object):
  9. def __init__(self, args, cfg, device, num_classes=80):
  10. self.args = args
  11. self.cfg = cfg
  12. self.device = device
  13. self.num_classes = num_classes
  14. self.max_epoch = args.max_epoch
  15. self.no_aug_epoch = args.no_aug_epoch
  16. self.aux_bbox_loss = False
  17. # --------------- Loss config ---------------
  18. self.loss_cls_weight = cfg['loss_cls_weight']
  19. self.loss_box_weight = cfg['loss_box_weight']
  20. # --------------- Matcher config ---------------
  21. self.matcher_hpy = cfg['matcher_hpy']
  22. self.matcher = AlignedSimOTA(soft_center_radius = self.matcher_hpy['soft_center_radius'],
  23. topk_candidates = self.matcher_hpy['topk_candidates'],
  24. num_classes = num_classes,
  25. )
  26. # -------------------- Basic loss functions --------------------
  27. def loss_classes(self, pred_cls, target, beta=2.0):
  28. # Quality FocalLoss
  29. """
  30. pred_cls: (torch.Tensor): [N, C]。
  31. target: (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N)
  32. """
  33. label, score = target
  34. pred_sigmoid = pred_cls.sigmoid()
  35. scale_factor = pred_sigmoid
  36. zerolabel = scale_factor.new_zeros(pred_cls.shape)
  37. ce_loss = F.binary_cross_entropy_with_logits(
  38. pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
  39. bg_class_ind = pred_cls.shape[-1]
  40. pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
  41. pos_label = label[pos].long()
  42. scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
  43. ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
  44. pred_cls[pos, pos_label], score[pos],
  45. reduction='none') * scale_factor.abs().pow(beta)
  46. return ce_loss
  47. def loss_bboxes(self, pred_box, gt_box):
  48. ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
  49. loss_box = 1.0 - ious
  50. return loss_box
  51. def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
  52. # xyxy -> cxcy&bwbh
  53. gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
  54. gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
  55. # encode gt box
  56. gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
  57. gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
  58. gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
  59. # l1 loss
  60. loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
  61. return loss_box_aux
  62. # -------------------- Task loss functions --------------------
  63. def compute_det_loss(self, outputs, targets, epoch=0):
  64. """
  65. Input:
  66. outputs: (Dict) -> {
  67. 'pred_cls': (List[torch.Tensor] -> [B, M, Nc]),
  68. 'pred_reg': (List[torch.Tensor] -> [B, M, 4]),
  69. 'pred_box': (List[torch.Tensor] -> [B, M, 4]),
  70. 'strides': (List[Int])
  71. }
  72. target: (List[Dict]) [
  73. {'boxes': (torch.Tensor) -> [N, 4],
  74. 'labels': (torch.Tensor) -> [N,],
  75. ...}, ...
  76. ]
  77. Output:
  78. loss_dict: (Dict) -> {
  79. 'loss_cls': (torch.Tensor) It is a scalar.),
  80. 'loss_box': (torch.Tensor) It is a scalar.),
  81. 'loss_box_aux': (torch.Tensor) It is a scalar.),
  82. 'losses': (torch.Tensor) It is a scalar.),
  83. }
  84. """
  85. bs = outputs['pred_cls'][0].shape[0]
  86. device = outputs['pred_cls'][0].device
  87. fpn_strides = outputs['strides']
  88. anchors = outputs['anchors']
  89. # preds: [B, M, C]
  90. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  91. box_preds = torch.cat(outputs['pred_box'], dim=1)
  92. # --------------- label assignment ---------------
  93. cls_targets = []
  94. box_targets = []
  95. assign_metrics = []
  96. for batch_idx in range(bs):
  97. tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
  98. tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
  99. assigned_result = self.matcher(fpn_strides=fpn_strides,
  100. anchors=anchors,
  101. pred_cls=cls_preds[batch_idx].detach(),
  102. pred_box=box_preds[batch_idx].detach(),
  103. gt_labels=tgt_labels,
  104. gt_bboxes=tgt_bboxes
  105. )
  106. cls_targets.append(assigned_result['assigned_labels'])
  107. box_targets.append(assigned_result['assigned_bboxes'])
  108. assign_metrics.append(assigned_result['assign_metrics'])
  109. # List[B, M, C] -> Tensor[BM, C]
  110. cls_targets = torch.cat(cls_targets, dim=0)
  111. box_targets = torch.cat(box_targets, dim=0)
  112. assign_metrics = torch.cat(assign_metrics, dim=0)
  113. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  114. bg_class_ind = self.num_classes
  115. pos_inds = ((cls_targets >= 0) & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
  116. num_fgs = assign_metrics.sum()
  117. if is_dist_avail_and_initialized():
  118. torch.distributed.all_reduce(num_fgs)
  119. num_fgs = (num_fgs / get_world_size()).clamp(1.0).item()
  120. # ------------------ Classification loss ------------------
  121. cls_preds = cls_preds.view(-1, self.num_classes)
  122. loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
  123. loss_cls = loss_cls.sum() / num_fgs
  124. # ------------------ Regression loss ------------------
  125. box_preds_pos = box_preds.view(-1, 4)[pos_inds]
  126. box_targets_pos = box_targets[pos_inds]
  127. loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
  128. loss_box = loss_box.sum() / num_fgs
  129. # total loss
  130. losses = self.loss_cls_weight * loss_cls + \
  131. self.loss_box_weight * loss_box
  132. # ------------------ Aux regression loss ------------------
  133. loss_box_aux = None
  134. if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
  135. ## reg_preds
  136. reg_preds = torch.cat(outputs['pred_reg'], dim=1)
  137. reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
  138. ## anchor tensors
  139. anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
  140. anchors_tensors_pos = anchors_tensors.view(-1, 2)[pos_inds]
  141. ## stride tensors
  142. stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
  143. stride_tensors_pos = stride_tensors.view(-1, 1)[pos_inds]
  144. ## aux loss
  145. loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets_pos, anchors_tensors_pos, stride_tensors_pos)
  146. loss_box_aux = loss_box_aux.sum() / num_fgs
  147. losses += loss_box_aux
  148. # Loss dict
  149. if loss_box_aux is None:
  150. loss_dict = dict(
  151. loss_cls = loss_cls,
  152. loss_box = loss_box,
  153. losses = losses
  154. )
  155. else:
  156. loss_dict = dict(
  157. loss_cls = loss_cls,
  158. loss_box = loss_box,
  159. loss_box_aux = loss_box_aux,
  160. losses = losses
  161. )
  162. return loss_dict
  163. def compute_seg_loss(self, outputs, targets, epoch=0):
  164. """
  165. Input:
  166. outputs: (Dict) -> {
  167. 'pred_cls': (List[torch.Tensor] -> [B, M, Nc]),
  168. 'pred_reg': (List[torch.Tensor] -> [B, M, 4]),
  169. 'pred_box': (List[torch.Tensor] -> [B, M, 4]),
  170. 'strides': (List[Int])
  171. }
  172. target: (List[Dict]) [
  173. {'boxes': (torch.Tensor) -> [N, 4],
  174. 'labels': (torch.Tensor) -> [N,],
  175. ...}, ...
  176. ]
  177. Output:
  178. loss_dict: (Dict) -> {
  179. 'loss_cls': (torch.Tensor) It is a scalar.),
  180. 'loss_box': (torch.Tensor) It is a scalar.),
  181. 'loss_box_aux': (torch.Tensor) It is a scalar.),
  182. 'losses': (torch.Tensor) It is a scalar.),
  183. }
  184. """
  185. def compute_pos_loss(self, outputs, targets, epoch=0):
  186. """
  187. Input:
  188. outputs: (Dict) -> {
  189. 'pred_cls': (List[torch.Tensor] -> [B, M, Nc]),
  190. 'pred_reg': (List[torch.Tensor] -> [B, M, 4]),
  191. 'pred_box': (List[torch.Tensor] -> [B, M, 4]),
  192. 'strides': (List[Int])
  193. }
  194. target: (List[Dict]) [
  195. {'boxes': (torch.Tensor) -> [N, 4],
  196. 'labels': (torch.Tensor) -> [N,],
  197. ...}, ...
  198. ]
  199. Output:
  200. loss_dict: (Dict) -> {
  201. 'loss_cls': (torch.Tensor) It is a scalar.),
  202. 'loss_box': (torch.Tensor) It is a scalar.),
  203. 'loss_box_aux': (torch.Tensor) It is a scalar.),
  204. 'losses': (torch.Tensor) It is a scalar.),
  205. }
  206. """
  207. def __call__(self, outputs, targets, epoch=0, task='det'):
  208. # -------------- Detection loss --------------
  209. det_loss_dict = None
  210. if outputs['det_outputs'] is not None:
  211. det_loss_dict = self.compute_det_loss(outputs['det_outputs'], targets, epoch)
  212. # -------------- Segmentation loss --------------
  213. seg_loss_dict = None
  214. if outputs['seg_outputs'] is not None:
  215. seg_loss_dict = self.compute_seg_loss(outputs['seg_outputs'], targets, epoch)
  216. # -------------- Human pose loss --------------
  217. pos_loss_dict = None
  218. if outputs['pos_outputs'] is not None:
  219. pos_loss_dict = self.compute_seg_loss(outputs['pos_outputs'], targets, epoch)
  220. # Loss dict
  221. if task == 'det':
  222. return det_loss_dict
  223. if task == 'det_seg':
  224. return {'det_loss_dict': det_loss_dict,
  225. 'seg_loss_dict': seg_loss_dict}
  226. if task == 'det_pos':
  227. return {'det_loss_dict': det_loss_dict,
  228. 'pos_loss_dict': pos_loss_dict}
  229. if task == 'det_seg_pos':
  230. return {'det_loss_dict': det_loss_dict,
  231. 'seg_loss_dict': seg_loss_dict,
  232. 'pos_loss_dict': pos_loss_dict}
  233. def build_criterion(args, cfg, device, num_classes):
  234. criterion = Criterion(args, cfg, device, num_classes)
  235. return criterion
  236. if __name__ == "__main__":
  237. pass