# loss.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .matcher import TaskAlignedAssigner, Yolov5Matcher
  5. from utils.box_ops import bbox_iou, get_ious
  6. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  7. class Criterion(object):
  8. def __init__(self,
  9. cfg,
  10. device,
  11. num_classes=80,
  12. warmup_epoch=1):
  13. # ------------------ Basic Parameters ------------------
  14. self.cfg = cfg
  15. self.device = device
  16. self.num_classes = num_classes
  17. self.warmup_epoch = warmup_epoch
  18. self.warmup_stage = True
  19. # ------------------ Loss Parameters ------------------
  20. ## loss function
  21. self.cls_lossf = ClassificationLoss(cfg, reduction='none')
  22. self.reg_lossf = RegressionLoss(num_classes)
  23. ## loss coeff
  24. self.loss_cls_weight = cfg['loss_cls_weight']
  25. self.loss_iou_weight = cfg['loss_iou_weight']
  26. # ------------------ Label Assigner ------------------
  27. matcher_config = cfg['matcher']
  28. ## matcher-1
  29. self.fixed_matcher = Yolov5Matcher(
  30. num_classes=num_classes,
  31. num_anchors=3,
  32. anchor_size=cfg['anchor_size'],
  33. anchor_theshold=matcher_config['anchor_thresh']
  34. )
  35. ## matcher-2
  36. self.dynamic_matcher = TaskAlignedAssigner(
  37. topk=matcher_config['topk'],
  38. num_classes=num_classes,
  39. alpha=matcher_config['alpha'],
  40. beta=matcher_config['beta']
  41. )
  42. def fixed_assignment_loss(self, outputs, targets):
  43. device = outputs['pred_cls'][0].device
  44. fpn_strides = outputs['strides']
  45. fmp_sizes = outputs['fmp_sizes']
  46. (
  47. gt_objectness,
  48. gt_classes,
  49. gt_bboxes,
  50. ) = self.fixed_matcher(fmp_sizes=fmp_sizes,
  51. fpn_strides=fpn_strides,
  52. targets=targets)
  53. # List[B, M, C] -> [B, M, C] -> [BM, C]
  54. pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes) # [BM, C]
  55. pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4) # [BM, 4]
  56. gt_objectness = gt_objectness.view(-1).to(device).float() # [BM,]
  57. gt_classes = gt_classes.view(-1, self.num_classes).to(device).float() # [BM, C]
  58. gt_bboxes = gt_bboxes.view(-1, 4).to(device).float() # [BM, 4]
  59. pos_masks = (gt_objectness > 0)
  60. num_fgs = pos_masks.sum()
  61. if is_dist_avail_and_initialized():
  62. torch.distributed.all_reduce(num_fgs)
  63. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  64. # box loss
  65. ious = get_ious(pred_box[pos_masks],
  66. gt_bboxes[pos_masks],
  67. box_mode="xyxy",
  68. iou_type='giou')
  69. loss_box = 1.0 - ious
  70. loss_box = loss_box.sum() / num_fgs
  71. # cls loss
  72. gt_classes[pos_masks] = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
  73. loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_classes, reduction='none')
  74. loss_cls = loss_cls.sum() / num_fgs
  75. # total loss
  76. losses = self.loss_cls_weight * loss_cls + \
  77. self.loss_iou_weight * loss_box
  78. loss_dict = dict(
  79. loss_cls = loss_cls,
  80. loss_box = loss_box,
  81. losses = losses
  82. )
  83. return loss_dict
  84. def dynamic_assignment_loss(self, outputs, targets):
  85. bs = outputs['pred_cls'][0].shape[0]
  86. device = outputs['pred_cls'][0].device
  87. anchors = outputs['anchors']
  88. anchors = torch.cat(anchors, dim=0)
  89. num_anchors = anchors.shape[0]
  90. # preds: [B, M, C]
  91. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  92. box_preds = torch.cat(outputs['pred_box'], dim=1)
  93. # label assignment
  94. gt_score_targets = []
  95. gt_bbox_targets = []
  96. fg_masks = []
  97. for batch_idx in range(bs):
  98. tgt_labels = targets[batch_idx]["labels"].to(device) # [Mp,]
  99. tgt_boxs = targets[batch_idx]["boxes"].to(device) # [Mp, 4]
  100. # check target
  101. if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
  102. # There is no valid gt
  103. fg_mask = cls_preds.new_zeros(1, num_anchors).bool() #[1, M,]
  104. gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
  105. gt_box = cls_preds.new_zeros((1, num_anchors, 4)) #[1, M, 4]
  106. else:
  107. tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
  108. tgt_boxs = tgt_boxs[None] # [1, Mp, 4]
  109. (
  110. _,
  111. gt_box, #[1, M, 4]
  112. gt_score, #[1, M, C]
  113. fg_mask, #[1, M,]
  114. _
  115. ) = self.dynamic_matcher(
  116. pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(),
  117. pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
  118. anc_points = anchors[..., :2],
  119. gt_labels = tgt_labels,
  120. gt_bboxes = tgt_boxs
  121. )
  122. gt_score_targets.append(gt_score)
  123. gt_bbox_targets.append(gt_box)
  124. fg_masks.append(fg_mask)
  125. # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
  126. fg_masks = torch.cat(fg_masks, 0).view(-1) # [BM,]
  127. gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes) # [BM, C]
  128. gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4) # [BM, 4]
  129. # cls loss
  130. cls_preds = cls_preds.view(-1, self.num_classes)
  131. loss_cls = self.cls_lossf(cls_preds, gt_score_targets)
  132. # reg loss
  133. bbox_weight = gt_score_targets[fg_masks].sum(-1, keepdim=True) # [BM, 1]
  134. box_preds = box_preds.view(-1, 4) # [BM, 4]
  135. loss_iou = self.reg_lossf(
  136. pred_boxs = box_preds,
  137. gt_boxs = gt_bbox_targets,
  138. bbox_weight = bbox_weight,
  139. fg_masks = fg_masks
  140. )
  141. num_fgs = gt_score_targets.sum()
  142. if is_dist_avail_and_initialized():
  143. torch.distributed.all_reduce(num_fgs)
  144. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  145. # normalize loss
  146. loss_cls = loss_cls.sum() / num_fgs
  147. loss_iou = loss_iou.sum() / num_fgs
  148. # total loss
  149. losses = loss_cls * self.loss_cls_weight + \
  150. loss_iou * self.loss_iou_weight
  151. loss_dict = dict(
  152. loss_cls = loss_cls,
  153. loss_iou = loss_iou,
  154. losses = losses
  155. )
  156. return loss_dict
  157. def __call__(self, outputs, targets, epoch=0):
  158. """
  159. outputs['pred_cls']: List(Tensor) [B, M, C]
  160. outputs['pred_regs']: List(Tensor) [B, M, 4*(reg_max+1)]
  161. outputs['pred_boxs']: List(Tensor) [B, M, 4]
  162. outputs['anchors']: List(Tensor) [M, 2]
  163. outputs['strides']: List(Int) [8, 16, 32] output stride
  164. outputs['stride_tensor']: List(Tensor) [M, 1]
  165. targets: (List) [dict{'boxes': [...],
  166. 'labels': [...],
  167. 'orig_size': ...}, ...]
  168. """
  169. # Fixed LA stage
  170. if epoch < self.warmup_epoch:
  171. return self.fixed_assignment_loss(outputs, targets)
  172. # Switch to Dynamic LA stage
  173. elif epoch == self.warmup_epoch:
  174. print('Switch to Dynamic Label Assignment.')
  175. return self.dynamic_assignment_loss(outputs, targets)
  176. # Dynamic LA stage
  177. else:
  178. return self.dynamic_assignment_loss(outputs, targets)
  179. class ClassificationLoss(nn.Module):
  180. def __init__(self, cfg, reduction='none'):
  181. super(ClassificationLoss, self).__init__()
  182. self.cfg = cfg
  183. self.reduction = reduction
  184. def binary_cross_entropy(self, pred_logits, gt_score):
  185. loss = F.binary_cross_entropy_with_logits(
  186. pred_logits.float(), gt_score.float(), reduction='none')
  187. if self.reduction == 'sum':
  188. loss = loss.sum()
  189. elif self.reduction == 'mean':
  190. loss = loss.mean()
  191. return loss
  192. def forward(self, pred_logits, gt_score):
  193. if self.cfg['cls_loss'] == 'bce':
  194. return self.binary_cross_entropy(pred_logits, gt_score)
  195. class RegressionLoss(nn.Module):
  196. def __init__(self, num_classes):
  197. super(RegressionLoss, self).__init__()
  198. self.num_classes = num_classes
  199. def forward(self, pred_boxs, gt_boxs, bbox_weight, fg_masks):
  200. """
  201. Input:
  202. pred_boxs: (Tensor) [BM, 4]
  203. anchors: (Tensor) [BM, 2]
  204. gt_boxs: (Tensor) [BM, 4]
  205. bbox_weight: (Tensor) [BM, 1]
  206. fg_masks: (Tensor) [BM,]
  207. strides: (Tensor) [BM, 1]
  208. """
  209. # select positive samples mask
  210. num_pos = fg_masks.sum()
  211. if num_pos > 0:
  212. pred_boxs_pos = pred_boxs[fg_masks]
  213. gt_boxs_pos = gt_boxs[fg_masks]
  214. # iou loss
  215. ious = bbox_iou(pred_boxs_pos,
  216. gt_boxs_pos,
  217. xywh=False,
  218. CIoU=True)
  219. loss_iou = (1.0 - ious) * bbox_weight
  220. else:
  221. loss_iou = pred_boxs.sum() * 0.
  222. return loss_iou
  223. def build_criterion(cfg, device, num_classes, warmup_epoch=1):
  224. criterion = Criterion(
  225. cfg=cfg,
  226. device=device,
  227. num_classes=num_classes,
  228. warmup_epoch=warmup_epoch,
  229. )
  230. return criterion
  231. if __name__ == "__main__":
  232. pass