loss.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from utils.box_ops import bbox2dist, bbox_iou
  5. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  6. from .matcher import SimOTA
  7. class Criterion(object):
  8. def __init__(self, cfg, device, num_classes=80):
  9. # --------------- Basic parameters ---------------
  10. self.cfg = cfg
  11. self.device = device
  12. self.num_classes = num_classes
  13. self.reg_max = cfg['reg_max']
  14. # --------------- Loss config ---------------
  15. self.loss_cls_weight = cfg['loss_cls_weight']
  16. self.loss_box_weight = cfg['loss_box_weight']
  17. self.loss_dfl_weight = cfg['loss_dfl_weight']
  18. # --------------- Matcher config ---------------
  19. self.matcher_hpy = cfg['matcher_hpy']
  20. self.matcher = SimOTA(center_sampling_radius = self.matcher_hpy['center_sampling_radius'],
  21. topk_candidate = self.matcher_hpy['topk_candidate'],
  22. num_classes = num_classes)
  23. def loss_classes(self, pred_cls, gt_score):
  24. # compute bce loss
  25. loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
  26. return loss_cls
  27. def loss_bboxes(self, pred_box, gt_box, bbox_weight):
  28. # regression loss
  29. ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
  30. loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
  31. return loss_box
  32. def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
  33. # rescale coords by stride
  34. gt_box_s = gt_box / stride
  35. anchor_s = anchor / stride
  36. # compute deltas
  37. gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.cfg['reg_max'] - 1)
  38. gt_left = gt_ltrb_s.to(torch.long)
  39. gt_right = gt_left + 1
  40. weight_left = gt_right.to(torch.float) - gt_ltrb_s
  41. weight_right = 1 - weight_left
  42. # loss left
  43. loss_left = F.cross_entropy(
  44. pred_reg.view(-1, self.cfg['reg_max']),
  45. gt_left.view(-1),
  46. reduction='none').view(gt_left.shape) * weight_left
  47. # loss right
  48. loss_right = F.cross_entropy(
  49. pred_reg.view(-1, self.cfg['reg_max']),
  50. gt_right.view(-1),
  51. reduction='none').view(gt_left.shape) * weight_right
  52. loss_dfl = (loss_left + loss_right).mean(-1)
  53. if bbox_weight is not None:
  54. loss_dfl *= bbox_weight
  55. return loss_dfl
  56. def __call__(self, outputs, targets, epoch=0):
  57. """
  58. outputs['pred_cls']: List(Tensor) [B, M, C]
  59. outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
  60. outputs['pred_box']: List(Tensor) [B, M, 4]
  61. outputs['anchors']: List(Tensor) [M, 2]
  62. outputs['strides']: List(Int) [8, 16, 32] output stride
  63. outputs['stride_tensor']: List(Tensor) [M, 1]
  64. targets: (List) [dict{'boxes': [...],
  65. 'labels': [...],
  66. 'orig_size': ...}, ...]
  67. """
  68. bs = outputs['pred_cls'][0].shape[0]
  69. device = outputs['pred_cls'][0].device
  70. anchors = outputs['anchors']
  71. fpn_strides = outputs['strides']
  72. # preds: [B, M, C]
  73. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  74. reg_preds = torch.cat(outputs['pred_reg'], dim=1)
  75. box_preds = torch.cat(outputs['pred_box'], dim=1)
  76. num_anchors = box_preds.shape[1]
  77. # --------------- label assignment ---------------
  78. cls_targets = []
  79. box_targets = []
  80. fg_masks = []
  81. for batch_idx in range(bs):
  82. tgt_labels = targets[batch_idx]["labels"].to(device)
  83. tgt_bboxes = targets[batch_idx]["boxes"].to(device)
  84. # check target
  85. if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
  86. # There is no valid gt
  87. cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
  88. box_target = cls_preds.new_zeros((0, 4))
  89. fg_mask = cls_preds.new_zeros(num_anchors).bool()
  90. else:
  91. (
  92. fg_mask,
  93. assigned_labels,
  94. assigned_ious,
  95. assigned_indexs
  96. ) = self.matcher(
  97. fpn_strides = fpn_strides,
  98. anchors = anchors,
  99. pred_cls = cls_preds[batch_idx],
  100. pred_box = box_preds[batch_idx],
  101. tgt_labels = tgt_labels,
  102. tgt_bboxes = tgt_bboxes
  103. )
  104. # prepare cls targets
  105. assigned_labels = F.one_hot(assigned_labels.long(), self.num_classes)
  106. assigned_labels = assigned_labels * assigned_ious.unsqueeze(-1)
  107. cls_target = assigned_labels.new_zeros((num_anchors, self.num_classes))
  108. cls_target[fg_mask] = assigned_labels
  109. # prepare box targets
  110. box_target = tgt_bboxes[assigned_indexs]
  111. cls_targets.append(cls_target)
  112. box_targets.append(box_target)
  113. fg_masks.append(fg_mask)
  114. cls_targets = torch.cat(cls_targets, 0)
  115. box_targets = torch.cat(box_targets, 0)
  116. fg_masks = torch.cat(fg_masks, 0)
  117. num_fgs = cls_targets.sum()
  118. # Average loss normalizer across all the GPUs
  119. if is_dist_avail_and_initialized():
  120. torch.distributed.all_reduce(num_fgs)
  121. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  122. # ------------------ Classification loss ------------------
  123. cls_preds = cls_preds.view(-1, self.num_classes)
  124. loss_cls = self.loss_classes(cls_preds, cls_targets)
  125. loss_cls = loss_cls.sum() / num_fgs
  126. # ------------------ Regression loss ------------------
  127. box_preds_pos = box_preds.view(-1, 4)[fg_masks]
  128. bbox_weight = cls_targets[fg_masks].sum(-1)
  129. loss_box = self.loss_bboxes(box_preds_pos, box_targets, bbox_weight)
  130. loss_box = loss_box.sum() / num_fgs
  131. # ------------------ Distribution focal loss ------------------
  132. ## process anchors
  133. anchors = torch.cat(anchors, dim=0)
  134. anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
  135. ## process stride tensors
  136. strides = torch.cat(outputs['stride_tensor'], dim=0)
  137. strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
  138. ## fg preds
  139. reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
  140. anchors_pos = anchors[fg_masks]
  141. strides_pos = strides[fg_masks]
  142. ## compute dfl
  143. loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos, bbox_weight)
  144. loss_dfl = loss_dfl.sum() / num_fgs
  145. # total loss
  146. losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight + loss_dfl * self.loss_dfl_weight
  147. loss_dict = dict(
  148. loss_cls = loss_cls,
  149. loss_box = loss_box,
  150. loss_dfl = loss_dfl,
  151. losses = losses
  152. )
  153. return loss_dict
  154. def build_criterion(cfg, device, num_classes):
  155. criterion = Criterion(
  156. cfg=cfg,
  157. device=device,
  158. num_classes=num_classes
  159. )
  160. return criterion
  161. if __name__ == "__main__":
  162. pass