# loss.py
  1. import torch
  2. import torch.nn.functional as F
  3. from .matcher import SimOTA
  4. from utils.box_ops import get_ious
  5. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  6. class Criterion(object):
  7. def __init__(self,
  8. args,
  9. cfg,
  10. device,
  11. num_classes=80):
  12. self.args = args
  13. self.cfg = cfg
  14. self.device = device
  15. self.num_classes = num_classes
  16. self.max_epoch = args.max_epoch
  17. self.no_aug_epoch = args.no_aug_epoch
  18. self.aux_bbox_loss = False
  19. # loss weight
  20. self.loss_obj_weight = cfg['loss_obj_weight']
  21. self.loss_cls_weight = cfg['loss_cls_weight']
  22. self.loss_box_weight = cfg['loss_box_weight']
  23. # matcher
  24. matcher_config = cfg['matcher']
  25. self.matcher = SimOTA(
  26. num_classes=num_classes,
  27. center_sampling_radius=matcher_config['center_sampling_radius'],
  28. topk_candidate=matcher_config['topk_candicate']
  29. )
  30. def loss_objectness(self, pred_obj, gt_obj):
  31. loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
  32. return loss_obj
  33. def loss_classes(self, pred_cls, gt_label):
  34. loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
  35. return loss_cls
  36. def loss_bboxes(self, pred_box, gt_box):
  37. # regression loss
  38. ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
  39. loss_box = 1.0 - ious
  40. return loss_box
  41. def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
  42. # xyxy -> cxcy&bwbh
  43. gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
  44. gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
  45. # encode gt box
  46. gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
  47. gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
  48. gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
  49. # l1 loss
  50. loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
  51. return loss_box_aux
  52. def __call__(self, outputs, targets, epoch=0):
  53. """
  54. outputs['pred_obj']: List(Tensor) [B, M, 1]
  55. outputs['pred_cls']: List(Tensor) [B, M, C]
  56. outputs['pred_box']: List(Tensor) [B, M, 4]
  57. outputs['pred_box']: List(Tensor) [B, M, 4]
  58. outputs['strides']: List(Int) [8, 16, 32] output stride
  59. targets: (List) [dict{'boxes': [...],
  60. 'labels': [...],
  61. 'orig_size': ...}, ...]
  62. """
  63. bs = outputs['pred_cls'][0].shape[0]
  64. device = outputs['pred_cls'][0].device
  65. fpn_strides = outputs['strides']
  66. anchors = outputs['anchors']
  67. # preds: [B, M, C]
  68. obj_preds = torch.cat(outputs['pred_obj'], dim=1)
  69. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  70. box_preds = torch.cat(outputs['pred_box'], dim=1)
  71. # label assignment
  72. cls_targets = []
  73. box_targets = []
  74. obj_targets = []
  75. fg_masks = []
  76. for batch_idx in range(bs):
  77. tgt_labels = targets[batch_idx]["labels"].to(device)
  78. tgt_bboxes = targets[batch_idx]["boxes"].to(device)
  79. # check target
  80. if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
  81. num_anchors = sum([ab.shape[0] for ab in anchors])
  82. # There is no valid gt
  83. cls_target = obj_preds.new_zeros((0, self.num_classes))
  84. box_target = obj_preds.new_zeros((0, 4))
  85. obj_target = obj_preds.new_zeros((num_anchors, 1))
  86. fg_mask = obj_preds.new_zeros(num_anchors).bool()
  87. else:
  88. (
  89. fg_mask,
  90. assigned_labels,
  91. assigned_ious,
  92. assigned_indexs
  93. ) = self.matcher(
  94. fpn_strides = fpn_strides,
  95. anchors = anchors,
  96. pred_obj = obj_preds[batch_idx],
  97. pred_cls = cls_preds[batch_idx],
  98. pred_box = box_preds[batch_idx],
  99. tgt_labels = tgt_labels,
  100. tgt_bboxes = tgt_bboxes
  101. )
  102. obj_target = fg_mask.unsqueeze(-1)
  103. cls_target = F.one_hot(assigned_labels.long(), self.num_classes)
  104. cls_target = cls_target * assigned_ious.unsqueeze(-1)
  105. box_target = tgt_bboxes[assigned_indexs]
  106. cls_targets.append(cls_target)
  107. box_targets.append(box_target)
  108. obj_targets.append(obj_target)
  109. fg_masks.append(fg_mask)
  110. cls_targets = torch.cat(cls_targets, 0)
  111. box_targets = torch.cat(box_targets, 0)
  112. obj_targets = torch.cat(obj_targets, 0)
  113. fg_masks = torch.cat(fg_masks, 0)
  114. num_fgs = fg_masks.sum()
  115. if is_dist_avail_and_initialized():
  116. torch.distributed.all_reduce(num_fgs)
  117. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  118. # ------------------ Objecntness loss ------------------
  119. loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
  120. loss_obj = loss_obj.sum() / num_fgs
  121. # ------------------ Classification loss ------------------
  122. cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
  123. loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
  124. loss_cls = loss_cls.sum() / num_fgs
  125. # ------------------ Regression loss ------------------
  126. box_preds_pos = box_preds.view(-1, 4)[fg_masks]
  127. loss_box = self.loss_bboxes(box_preds_pos, box_targets)
  128. loss_box = loss_box.sum() / num_fgs
  129. # total loss
  130. losses = self.loss_obj_weight * loss_obj + \
  131. self.loss_cls_weight * loss_cls + \
  132. self.loss_box_weight * loss_box
  133. loss_dict = dict(
  134. loss_obj = loss_obj,
  135. loss_cls = loss_cls,
  136. loss_box = loss_box,
  137. losses = losses
  138. )
  139. # ------------------ Aux regression loss ------------------
  140. loss_box_aux = None
  141. if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
  142. ## reg_preds
  143. reg_preds = torch.cat(outputs['pred_reg'], dim=1)
  144. reg_preds_pos = reg_preds.view(-1, 4)[fg_masks]
  145. ## anchor tensors
  146. anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
  147. anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
  148. ## stride tensors
  149. stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
  150. stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
  151. ## aux loss
  152. loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets, anchors_tensors_pos, stride_tensors_pos)
  153. loss_box_aux = loss_box_aux.sum() / num_fgs
  154. losses += loss_box_aux
  155. # Loss dict
  156. if loss_box_aux is None:
  157. loss_dict = dict(
  158. loss_obj = loss_obj,
  159. loss_cls = loss_cls,
  160. loss_box = loss_box,
  161. losses = losses
  162. )
  163. else:
  164. loss_dict = dict(
  165. loss_obj = loss_obj,
  166. loss_cls = loss_cls,
  167. loss_box = loss_box,
  168. loss_box_aux = loss_box_aux,
  169. losses = losses
  170. )
  171. return loss_dict
  172. def build_criterion(args, cfg, device, num_classes):
  173. criterion = Criterion(
  174. args=args,
  175. cfg=cfg,
  176. device=device,
  177. num_classes=num_classes
  178. )
  179. return criterion
  180. if __name__ == "__main__":
  181. pass