loss.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .matcher import TaskAlignedAssigner, Yolov5Matcher
  5. from utils.box_ops import bbox_iou, get_ious
  6. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  class Criterion(object):
      """Detection loss with a two-stage label-assignment schedule.

      While ``epoch < warmup_epoch`` targets come from ``self.fixed_matcher``
      (a ``Yolov5Matcher``); from ``warmup_epoch`` onwards they come from
      ``self.dynamic_matcher`` (a ``TaskAlignedAssigner``).
      """

      def __init__(self,
                   cfg,
                   device,
                   num_classes=80,
                   warmup_epoch=1):
          """
          Args:
              cfg: config dict. Keys read here: 'loss_cls_weight',
                  'loss_iou_weight', 'anchor_size' and 'matcher' (which must
                  provide 'anchor_thresh', 'topk', 'alpha', 'beta'). The same
                  dict is handed to ClassificationLoss.
              device: stored as ``self.device``; the loss methods take the
                  device from the predictions instead.
              num_classes: number of object classes C.
              warmup_epoch: first epoch that uses the dynamic assigner.
          """
          # ------------------ Basic Parameters ------------------
          self.cfg = cfg
          self.device = device
          self.num_classes = num_classes
          self.warmup_epoch = warmup_epoch
          # The stage-switch message is printed only once, not on every
          # iteration of the switching epoch.
          self._switch_logged = False
          # ------------------ Loss Parameters ------------------
          ## loss function
          self.cls_lossf = ClassificationLoss(cfg, reduction='none')
          self.reg_lossf = RegressionLoss(num_classes)
          ## loss coeff
          self.loss_cls_weight = cfg['loss_cls_weight']
          self.loss_iou_weight = cfg['loss_iou_weight']
          # ------------------ Label Assigner ------------------
          matcher_config = cfg['matcher']
          ## matcher-1: fixed assignment used during warmup
          self.fixed_matcher = Yolov5Matcher(
              num_classes=num_classes,
              num_anchors=3,
              anchor_size=cfg['anchor_size'],
              # keyword spelling kept as-is; presumably it matches the
              # Yolov5Matcher signature (not visible here) - do not "fix" it here.
              anchor_theshold=matcher_config['anchor_thresh']
          )
          ## matcher-2: dynamic assignment used after warmup
          self.dynamic_matcher = TaskAlignedAssigner(
              topk=matcher_config['topk'],
              num_classes=num_classes,
              alpha=matcher_config['alpha'],
              beta=matcher_config['beta']
          )

      @staticmethod
      def _average_num_fgs(num_fgs):
          """Average a foreground count over ranks and clamp it to >= 1.

          When torch.distributed is initialised the count is summed in place
          across ranks first; it is then divided by the world size so every
          rank normalises its loss by the same value.
          """
          if is_dist_avail_and_initialized():
              torch.distributed.all_reduce(num_fgs)
          return (num_fgs / get_world_size()).clamp(1.0)

      def fixed_assignment_loss(self, outputs, targets):
          """Loss under the fixed (Yolov5Matcher) label assignment.

          Reads outputs['pred_cls'], outputs['pred_box'], outputs['strides']
          and outputs['fmp_sizes']. Returns a dict with 'loss_cls',
          'loss_box' and the weighted total 'losses'.
          """
          device = outputs['pred_cls'][0].device
          gt_objectness, gt_classes, gt_bboxes = self.fixed_matcher(
              fmp_sizes=outputs['fmp_sizes'],
              fpn_strides=outputs['strides'],
              targets=targets,
          )
          # List[B, M, C] -> [B, M, C] -> [BM, C]
          pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)  # [BM, C]
          pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)                 # [BM, 4]
          gt_objectness = gt_objectness.view(-1).to(device).float()                    # [BM,]
          gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()        # [BM, C]
          gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                         # [BM, 4]

          pos_masks = (gt_objectness > 0)
          num_fgs = self._average_num_fgs(pos_masks.sum())

          # box loss: GIoU between positive predictions and their targets
          ious = get_ious(pred_box[pos_masks],
                          gt_bboxes[pos_masks],
                          box_mode="xyxy",
                          iou_type='giou')
          loss_box = (1.0 - ious).sum() / num_fgs

          # cls loss: positive class targets are softened by the (non-negative)
          # IoU. The IoU is detached so the BCE target is a constant and no
          # gradient flows into the box branch through the target.
          iou_targets = ious.detach().clamp(0.).unsqueeze(-1)
          gt_classes[pos_masks] = gt_classes[pos_masks] * iou_targets
          loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_classes, reduction='none')
          loss_cls = loss_cls.sum() / num_fgs

          # total loss
          losses = self.loss_cls_weight * loss_cls + \
                   self.loss_iou_weight * loss_box

          return dict(
              loss_cls=loss_cls,
              loss_box=loss_box,
              losses=losses,
          )

      def dynamic_assignment_loss(self, outputs, targets):
          """Loss under the dynamic (TaskAlignedAssigner) label assignment.

          Assignment runs image by image on detached predictions. Images with
          no labels, or whose boxes are all zero, get all-background targets.
          Reads outputs['pred_cls'], outputs['pred_box'] and outputs['anchors'].
          Returns a dict with 'loss_cls', 'loss_iou' and the weighted total
          'losses'. Both losses are normalised by the sum of the soft score
          targets.
          """
          bs = outputs['pred_cls'][0].shape[0]
          device = outputs['pred_cls'][0].device
          anchors = torch.cat(outputs['anchors'], dim=0)
          num_anchors = anchors.shape[0]
          # preds: [B, M, C] and [B, M, 4]
          cls_preds = torch.cat(outputs['pred_cls'], dim=1)
          box_preds = torch.cat(outputs['pred_box'], dim=1)

          # ---------------- label assignment ----------------
          gt_score_targets = []
          gt_bbox_targets = []
          fg_masks = []
          for batch_idx in range(bs):
              tgt_labels = targets[batch_idx]["labels"].to(device)  # [Mp,]
              tgt_boxs = targets[batch_idx]["boxes"].to(device)     # [Mp, 4]
              if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                  # No valid gt: every anchor is background.
                  fg_mask = cls_preds.new_zeros(1, num_anchors).bool()                  # [1, M]
                  gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes))  # [1, M, C]
                  gt_box = cls_preds.new_zeros((1, num_anchors, 4))                   # [1, M, 4]
              else:
                  (
                      _,
                      gt_box,    # [1, M, 4]
                      gt_score,  # [1, M, C]
                      fg_mask,   # [1, M]
                      _
                  ) = self.dynamic_matcher(
                      pd_scores=cls_preds[batch_idx:batch_idx + 1].detach().sigmoid(),
                      pd_bboxes=box_preds[batch_idx:batch_idx + 1].detach(),
                      anc_points=anchors[..., :2],
                      gt_labels=tgt_labels[None, :, None],  # [1, Mp, 1]
                      gt_bboxes=tgt_boxs[None]              # [1, Mp, 4]
                  )
              gt_score_targets.append(gt_score)
              gt_bbox_targets.append(gt_box)
              fg_masks.append(fg_mask)

          # List[1, M, ...] -> [B, M, ...] -> [BM, ...]
          fg_masks = torch.cat(fg_masks, 0).view(-1)                                            # [BM,]
          gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
          gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]

          # cls loss (unreduced)
          cls_preds = cls_preds.view(-1, self.num_classes)
          loss_cls = self.cls_lossf(cls_preds, gt_score_targets)

          # reg loss, weighted per foreground anchor by its total target score
          bbox_weight = gt_score_targets[fg_masks].sum(-1, keepdim=True)  # [num_fg, 1]
          box_preds = box_preds.view(-1, 4)                                # [BM, 4]
          loss_iou = self.reg_lossf(
              pred_boxs=box_preds,
              gt_boxs=gt_bbox_targets,
              bbox_weight=bbox_weight,
              fg_masks=fg_masks
          )

          # normalize by the (rank-averaged) sum of soft score targets
          num_fgs = self._average_num_fgs(gt_score_targets.sum())
          loss_cls = loss_cls.sum() / num_fgs
          loss_iou = loss_iou.sum() / num_fgs

          # total loss
          losses = loss_cls * self.loss_cls_weight + \
                   loss_iou * self.loss_iou_weight

          return dict(
              loss_cls=loss_cls,
              loss_iou=loss_iou,
              losses=losses,
          )

      def __call__(self, outputs, targets, epoch=0):
          """Compute the loss for one batch.

          Args:
              outputs: dict of model outputs. Keys read:
                  'pred_cls'  List[Tensor] [B, M, C] class logits;
                  'pred_box'  List[Tensor] [B, M, 4] predicted boxes;
                  'anchors'   List[Tensor], first two columns used as anchor
                              points (dynamic stage);
                  'strides', 'fmp_sizes' (fixed stage, passed to the matcher).
              targets: list with one dict per image, containing at least
                  'boxes' and 'labels'.
              epoch: current epoch; selects the label-assignment stage.

          Returns:
              dict of loss tensors, including the weighted total 'losses'.
          """
          # Fixed LA stage
          if epoch < self.warmup_epoch:
              return self.fixed_assignment_loss(outputs, targets)
          # Dynamic LA stage; announce the switch once rather than per iteration.
          if epoch == self.warmup_epoch and not self._switch_logged:
              print('Switch to Dynamic Label Assignment.')
              self._switch_logged = True
          return self.dynamic_assignment_loss(outputs, targets)
  class ClassificationLoss(nn.Module):
      """Classification loss selected by ``cfg['cls_loss']``.

      Only ``'bce'`` (binary cross-entropy on logits) is implemented.
      """

      def __init__(self, cfg, reduction='none'):
          """
          Args:
              cfg: config dict; ``cfg['cls_loss']`` picks the loss type.
              reduction: 'sum' or 'mean' reduce the loss; any other value
                  (e.g. 'none') returns the element-wise loss.
          """
          super(ClassificationLoss, self).__init__()
          self.cfg = cfg
          self.reduction = reduction

      def binary_cross_entropy(self, pred_logits, gt_score):
          """BCE-with-logits computed in float32, then reduced per ``self.reduction``."""
          loss = F.binary_cross_entropy_with_logits(
              pred_logits.float(), gt_score.float(), reduction='none')
          if self.reduction == 'sum':
              loss = loss.sum()
          elif self.reduction == 'mean':
              loss = loss.mean()
          return loss

      def forward(self, pred_logits, gt_score):
          """Return the configured loss between logits and (soft) targets.

          Raises:
              NotImplementedError: if ``cfg['cls_loss']`` is not a supported type.
                  Without this, the method would return ``None`` and the failure
                  would surface later and far from its cause.
          """
          loss_type = self.cfg['cls_loss']
          if loss_type == 'bce':
              return self.binary_cross_entropy(pred_logits, gt_score)
          raise NotImplementedError(f"Unsupported cls_loss: {loss_type!r}")
  class RegressionLoss(nn.Module):
      """Box regression loss: weighted ``1 - CIoU`` over foreground samples."""

      def __init__(self, num_classes):
          super().__init__()
          self.num_classes = num_classes

      def forward(self, pred_boxs, gt_boxs, bbox_weight, fg_masks):
          """
          Args:
              pred_boxs: (Tensor) [BM, 4] predicted boxes (compared with xywh=False).
              gt_boxs: (Tensor) [BM, 4] assigned target boxes.
              bbox_weight: (Tensor) per-foreground weight multiplied into
                  ``1 - IoU``.
              fg_masks: (Tensor) [BM,] boolean foreground mask.

          Returns:
              The unreduced weighted ``1 - CIoU`` of the foreground samples, or,
              when there are none, a zero scalar that stays attached to
              ``pred_boxs``' autograd graph.
          """
          foreground_count = fg_masks.sum()
          if not foreground_count > 0:
              # Keep the graph connected so backward() still works.
              return pred_boxs.sum() * 0.

          ciou = bbox_iou(pred_boxs[fg_masks], gt_boxs[fg_masks],
                          xywh=False, CIoU=True)
          return (1.0 - ciou) * bbox_weight
  def build_criterion(cfg, device, num_classes, warmup_epoch=1):
      """Factory returning a ``Criterion`` built from the given arguments.

      All arguments are forwarded unchanged as keywords.
      """
      return Criterion(cfg=cfg, device=device,
                       num_classes=num_classes, warmup_epoch=warmup_epoch)
  if __name__ == "__main__":
      # No script entry point: running this file directly does nothing.
      ...