loss.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from .matcher import TaskAlignedAssigner
from utils.box_ops import bbox2dist, bbox_iou


class Criterion(object):
    def __init__(self,
                 cfg,
                 device,
                 num_classes=80):
        self.cfg = cfg
        self.device = device
        self.num_classes = num_classes
        self.reg_max = cfg['reg_max']
        self.use_dfl = cfg['reg_max'] > 1
        # loss
        self.cls_lossf = ClassificationLoss(cfg, reduction='none')
        self.reg_lossf = RegressionLoss(num_classes, cfg['reg_max'] - 1, self.use_dfl)
        # loss weight
        self.loss_cls_weight = cfg['loss_cls_weight']
        self.loss_iou_weight = cfg['loss_iou_weight']
        self.loss_dfl_weight = cfg['loss_dfl_weight']
        # matcher
        matcher_config = cfg['matcher']
        self.matcher = TaskAlignedAssigner(
            topk=matcher_config['topk'],
            num_classes=num_classes,
            alpha=matcher_config['alpha'],
            beta=matcher_config['beta']
        )

    def __call__(self, outputs, targets):
        """
        outputs['pred_cls']: List(Tensor) [B, M, C]
        outputs['pred_reg']: List(Tensor) [B, M, 4*reg_max]
        outputs['pred_box']: List(Tensor) [B, M, 4]
        outputs['anchors']: List(Tensor) [M, 2]
        outputs['strides']: List(Int) [8, 16, 32] output strides
        outputs['stride_tensor']: List(Tensor) [M, 1]
        targets: (List) [dict{'boxes': [...],
                              'labels': [...],
                              'orig_size': ...}, ...]
        """
        bs = outputs['pred_cls'][0].shape[0]
        device = outputs['pred_cls'][0].device
        strides = outputs['stride_tensor']
        anchors = outputs['anchors']
        anchors = torch.cat(anchors, dim=0)
        num_anchors = anchors.shape[0]
        # predictions from all levels: [B, M, C], [B, M, 4*reg_max], [B, M, 4]
        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
        box_preds = torch.cat(outputs['pred_box'], dim=1)
        # label assignment
        gt_label_targets = []
        gt_score_targets = []
        gt_bbox_targets = []
        fg_masks = []
        for batch_idx in range(bs):
            tgt_labels = targets[batch_idx]["labels"].to(device)  # [Mp,]
            tgt_boxs = targets[batch_idx]["boxes"].to(device)     # [Mp, 4]
            # check target
            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                # There is no valid gt
                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()                # [1, M,]
                gt_label = cls_preds.new_zeros((1, num_anchors,))                   # [1, M,]
                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes))  # [1, M, C]
                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                   # [1, M, 4]
            else:
                tgt_labels = tgt_labels[None, :, None]  # [1, Mp, 1]
                tgt_boxs = tgt_boxs[None]               # [1, Mp, 4]
                (
                    gt_label,  # [1, M]
                    gt_box,    # [1, M, 4]
                    gt_score,  # [1, M, C]
                    fg_mask,   # [1, M,]
                    _
                ) = self.matcher(
                    pd_scores=cls_preds[batch_idx:batch_idx+1].detach().sigmoid(),
                    pd_bboxes=box_preds[batch_idx:batch_idx+1].detach(),
                    anc_points=anchors,
                    gt_labels=tgt_labels,
                    gt_bboxes=tgt_boxs
                )
            gt_label_targets.append(gt_label)
            gt_score_targets.append(gt_score)
            gt_bbox_targets.append(gt_box)
            fg_masks.append(fg_mask)
        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
        # cls loss
        cls_preds = cls_preds.view(-1, self.num_classes)
        # background anchors get the extra index `num_classes`; that column is
        # dropped again after one-hot encoding, leaving an all-zero row for them
        gt_label_targets = torch.where(
            fg_masks > 0,
            gt_label_targets,
            torch.full_like(gt_label_targets, self.num_classes)
        )
        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
        loss_cls = self.cls_lossf(cls_preds, gt_score_targets, gt_labels_one_hot)
        # reg loss
        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)                           # [BM, 2]
        strides = torch.cat(strides, dim=0).unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)  # [BM, 1]
        bbox_weight = gt_score_targets[fg_masks].sum(-1, keepdim=True)                 # [num_fg, 1]
        reg_preds = reg_preds.view(-1, 4*self.reg_max)                                 # [BM, 4*reg_max]
        box_preds = box_preds.view(-1, 4)                                              # [BM, 4]
        loss_iou, loss_dfl = self.reg_lossf(
            pred_regs=reg_preds,
            pred_boxs=box_preds,
            anchors=anchors,
            gt_boxs=gt_bbox_targets,
            bbox_weight=bbox_weight,
            fg_masks=fg_masks,
            strides=strides,
        )
        # normalize loss
        gt_score_targets_sum = max(gt_score_targets.sum(), 1)
        loss_cls = loss_cls.sum() / gt_score_targets_sum
        loss_iou = loss_iou.sum() / gt_score_targets_sum
        loss_dfl = loss_dfl.sum() / gt_score_targets_sum
        # total loss
        losses = loss_cls * self.loss_cls_weight + \
                 loss_iou * self.loss_iou_weight
        if self.use_dfl:
            losses += loss_dfl * self.loss_dfl_weight
            loss_dict = dict(
                loss_cls=loss_cls,
                loss_iou=loss_iou,
                loss_dfl=loss_dfl,
                losses=losses
            )
        else:
            loss_dict = dict(
                loss_cls=loss_cls,
                loss_iou=loss_iou,
                losses=losses
            )
        return loss_dict
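

# A small, self-contained sketch (not part of the training code above) of the
# background-index trick used in the cls-loss step: anchors outside fg_mask are
# mapped to the extra index `num_classes`, and that column is dropped after
# one-hot encoding, so background rows become all zeros. Names and values here
# are illustrative only.
def _demo_background_one_hot():
    num_classes = 3
    labels = torch.tensor([2, 0, 1, 1])                 # per-anchor class ids
    fg_mask = torch.tensor([True, False, True, False])  # which anchors are positive
    labels = torch.where(fg_mask, labels, torch.full_like(labels, num_classes))
    one_hot = F.one_hot(labels, num_classes + 1)[..., :-1]
    # rows 1 and 3 are all zeros (background), rows 0 and 2 stay one-hot
    return one_hot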


class ClassificationLoss(nn.Module):
    def __init__(self, cfg, reduction='none'):
        super(ClassificationLoss, self).__init__()
        self.cfg = cfg
        self.reduction = reduction
        # For VFL
        self.alpha = 0.75
        self.gamma = 2.0

    def varifocalloss(self, pred_logits, gt_score, gt_label, alpha=0.75, gamma=2.0):
        focal_weight = alpha * pred_logits.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
        with torch.cuda.amp.autocast(enabled=False):
            bce_loss = F.binary_cross_entropy_with_logits(
                pred_logits.float(), gt_score.float(), reduction='none')
            loss = bce_loss * focal_weight

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss

    def binary_cross_entropy(self, pred_logits, gt_score):
        loss = F.binary_cross_entropy_with_logits(
            pred_logits.float(), gt_score.float(), reduction='none')

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss

    def forward(self, pred_logits, gt_score, gt_label):
        if self.cfg['cls_loss'] == 'bce':
            return self.binary_cross_entropy(pred_logits, gt_score)
        elif self.cfg['cls_loss'] == 'vfl':
            return self.varifocalloss(pred_logits, gt_score, gt_label, self.alpha, self.gamma)
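

# Standalone sketch (illustrative values, not from the original file) of the
# varifocal weighting used above: negatives are down-weighted by the predicted
# probability raised to `gamma`, while positives are weighted by the IoU-aware
# soft score `gt_score`.
def _demo_varifocal_weight(alpha=0.75, gamma=2.0):
    pred_logits = torch.tensor([[2.0, -1.0, 0.5]])  # raw scores for 3 classes
    gt_label = torch.tensor([[1.0, 0.0, 0.0]])      # hard one-hot assignment
    gt_score = torch.tensor([[0.8, 0.0, 0.0]])      # IoU-aware soft target
    focal_weight = alpha * pred_logits.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
    bce = F.binary_cross_entropy_with_logits(pred_logits, gt_score, reduction='none')
    return bce * focal_weight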


class RegressionLoss(nn.Module):
    def __init__(self, num_classes, reg_max, use_dfl):
        super(RegressionLoss, self).__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def df_loss(self, pred_regs, target):
        # Distribution Focal Loss: each continuous target is split between the
        # two nearest integer bins, weighted by linear interpolation.
        gt_left = target.to(torch.long)
        gt_right = gt_left + 1
        weight_left = gt_right.to(torch.float) - target
        weight_right = 1 - weight_left
        # loss left
        loss_left = F.cross_entropy(
            pred_regs.view(-1, self.reg_max + 1),
            gt_left.view(-1),
            reduction='none').view(gt_left.shape) * weight_left
        # loss right
        loss_right = F.cross_entropy(
            pred_regs.view(-1, self.reg_max + 1),
            gt_right.view(-1),
            reduction='none').view(gt_left.shape) * weight_right
        loss = (loss_left + loss_right).mean(-1, keepdim=True)

        return loss

    def forward(self, pred_regs, pred_boxs, anchors, gt_boxs, bbox_weight, fg_masks, strides):
        """
        Input:
            pred_regs: (Tensor) [BM, 4*(reg_max + 1)]
            pred_boxs: (Tensor) [BM, 4]
            anchors: (Tensor) [BM, 2]
            gt_boxs: (Tensor) [BM, 4]
            bbox_weight: (Tensor) [num_fg, 1]
            fg_masks: (Tensor) [BM,]
            strides: (Tensor) [BM, 1]
        """
        # select positive samples mask
        num_pos = fg_masks.sum()
        if num_pos > 0:
            pred_boxs_pos = pred_boxs[fg_masks]
            gt_boxs_pos = gt_boxs[fg_masks]
            # iou loss
            ious = bbox_iou(pred_boxs_pos,
                            gt_boxs_pos,
                            xywh=False,
                            CIoU=True)
            loss_iou = (1.0 - ious) * bbox_weight
            # dfl loss
            if self.use_dfl:
                pred_regs_pos = pred_regs[fg_masks]
                gt_boxs_s = gt_boxs / strides
                anchors_s = anchors / strides
                gt_ltrb_s = bbox2dist(anchors_s, gt_boxs_s, self.reg_max)
                gt_ltrb_s_pos = gt_ltrb_s[fg_masks]
                loss_dfl = self.df_loss(pred_regs_pos, gt_ltrb_s_pos)
                loss_dfl *= bbox_weight
            else:
                loss_dfl = pred_regs.sum() * 0.
        else:
            loss_iou = pred_regs.sum() * 0.
            loss_dfl = pred_regs.sum() * 0.

        return loss_iou, loss_dfl
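

# Standalone sketch (illustrative numbers only) of the DFL target interpolation
# in `df_loss`: a continuous distance t is split between bins floor(t) and
# floor(t) + 1 with weights (floor(t) + 1 - t) and (t - floor(t)).
def _demo_dfl_interpolation():
    reg_max = 15                             # 16 bins: 0 .. reg_max
    target = torch.tensor([[3.3]])           # continuous ltrb distance (stride units)
    pred_regs = torch.randn(1, reg_max + 1)  # logits over the 16 bins for one side
    gt_left = target.to(torch.long)                  # bin 3
    gt_right = gt_left + 1                           # bin 4
    weight_left = gt_right.to(torch.float) - target  # 0.7
    weight_right = 1 - weight_left                   # 0.3
    loss = (F.cross_entropy(pred_regs, gt_left.view(-1), reduction='none') * weight_left.view(-1)
            + F.cross_entropy(pred_regs, gt_right.view(-1), reduction='none') * weight_right.view(-1))
    return loss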


def build_criterion(cfg, device, num_classes):
    criterion = Criterion(
        cfg=cfg,
        device=device,
        num_classes=num_classes
    )
    return criterion
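

# Hedged usage sketch (not from the original repo): the cfg keys below are
# exactly the ones this file reads; the numeric values are placeholder
# assumptions, not the project's official defaults.
def _demo_build_criterion():
    example_cfg = {
        'reg_max': 16,
        'loss_cls_weight': 0.5,
        'loss_iou_weight': 7.5,
        'loss_dfl_weight': 1.5,
        'cls_loss': 'bce',
        'matcher': {'topk': 10, 'alpha': 0.5, 'beta': 6.0},
    }
    device = torch.device('cpu')
    criterion = build_criterion(example_cfg, device, num_classes=80)
    # criterion(outputs, targets) expects the dict / list structure documented
    # in Criterion.__call__
    return criterion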


if __name__ == "__main__":
    pass