loss.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. import torch
  2. import torch.nn.functional as F
  3. from utils.box_ops import bbox2dist, get_ious
  4. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  5. from .matcher import TaskAlignedAssigner, AlignedSimOTA
  6. class Criterion(object):
  7. def __init__(self, args, cfg, device, num_classes=80):
  8. self.cfg = cfg
  9. self.args = args
  10. self.device = device
  11. self.num_classes = num_classes
  12. self.use_ema_update = cfg['ema_update']
  13. # loss weight
  14. self.loss_cls_weight = cfg['loss_cls_weight']
  15. self.loss_box_weight = cfg['loss_box_weight']
  16. self.loss_dfl_weight = cfg['loss_dfl_weight']
  17. # matcher
  18. matcher_config = cfg['matcher']
  19. self.tal_matcher = TaskAlignedAssigner(
  20. topk=matcher_config['tal']['topk'],
  21. alpha=matcher_config['tal']['alpha'],
  22. beta=matcher_config['tal']['beta'],
  23. num_classes=num_classes
  24. )
  25. self.ota_matcher = AlignedSimOTA(
  26. center_sampling_radius=matcher_config['ota']['center_sampling_radius'],
  27. topk_candidate=matcher_config['ota']['topk_candidate'],
  28. num_classes=num_classes
  29. )
  30. def __call__(self, outputs, targets, epoch=0):
  31. if epoch < self.args.wp_epoch:
  32. return self.ota_loss(outputs, targets)
  33. else:
  34. return self.tal_loss(outputs, targets)
  35. def ema_update(self, name: str, value, initial_value, momentum=0.9):
  36. if hasattr(self, name):
  37. old = getattr(self, name)
  38. else:
  39. old = initial_value
  40. new = old * momentum + value * (1 - momentum)
  41. setattr(self, name, new)
  42. return new
  43. # ----------------- Loss functions -----------------
  44. def loss_classes(self, pred_cls, gt_score, gt_label=None, vfl=False):
  45. if vfl:
  46. assert gt_label is not None
  47. # compute varifocal loss
  48. alpha, gamma = 0.75, 2.0
  49. focal_weight = alpha * pred_cls.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
  50. bce_loss = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
  51. loss_cls = bce_loss * focal_weight
  52. else:
  53. # compute bce loss
  54. loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
  55. return loss_cls
  56. def loss_bboxes(self, pred_box, gt_box, bbox_weight=None):
  57. # regression loss
  58. ious = get_ious(pred_box, gt_box, 'xyxy', 'giou')
  59. loss_box = 1.0 - ious
  60. if bbox_weight is not None:
  61. loss_box *= bbox_weight
  62. return loss_box
  63. def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
  64. # rescale coords by stride
  65. gt_box_s = gt_box / stride
  66. anchor_s = anchor / stride
  67. # compute deltas
  68. gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.cfg['reg_max'] - 1)
  69. gt_left = gt_ltrb_s.to(torch.long)
  70. gt_right = gt_left + 1
  71. weight_left = gt_right.to(torch.float) - gt_ltrb_s
  72. weight_right = 1 - weight_left
  73. # loss left
  74. loss_left = F.cross_entropy(
  75. pred_reg.view(-1, self.cfg['reg_max']),
  76. gt_left.view(-1),
  77. reduction='none').view(gt_left.shape) * weight_left
  78. # loss right
  79. loss_right = F.cross_entropy(
  80. pred_reg.view(-1, self.cfg['reg_max']),
  81. gt_right.view(-1),
  82. reduction='none').view(gt_left.shape) * weight_right
  83. loss_dfl = (loss_left + loss_right).mean(-1)
  84. if bbox_weight is not None:
  85. loss_dfl *= bbox_weight
  86. return loss_dfl
  87. # ----------------- Loss with TAL assigner -----------------
  88. def tal_loss(self, outputs, targets):
  89. """ Compute loss with TAL assigner """
  90. bs = outputs['pred_cls'][0].shape[0]
  91. device = outputs['pred_cls'][0].device
  92. anchors = torch.cat(outputs['anchors'], dim=0)
  93. num_anchors = anchors.shape[0]
  94. # preds: [B, M, C]
  95. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  96. reg_preds = torch.cat(outputs['pred_reg'], dim=1)
  97. box_preds = torch.cat(outputs['pred_box'], dim=1)
  98. # --------------- label assignment ---------------
  99. gt_label_targets = []
  100. gt_score_targets = []
  101. gt_bbox_targets = []
  102. fg_masks = []
  103. for batch_idx in range(bs):
  104. tgt_labels = targets[batch_idx]["labels"].to(device)
  105. tgt_bboxes = targets[batch_idx]["boxes"].to(device)
  106. # check target
  107. if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
  108. # There is no valid gt
  109. fg_mask = cls_preds.new_zeros(1, num_anchors).bool() #[1, M,]
  110. gt_label = cls_preds.new_zeros((1, num_anchors,)) #[1, M,]
  111. gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
  112. gt_box = cls_preds.new_zeros((1, num_anchors, 4)) #[1, M, 4]
  113. else:
  114. tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
  115. tgt_bboxes = tgt_bboxes[None] # [1, Mp, 4]
  116. (
  117. gt_label, #[1, M]
  118. gt_box, #[1, M, 4]
  119. gt_score, #[1, M, C]
  120. fg_mask, #[1, M,]
  121. _
  122. ) = self.tal_matcher(
  123. pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(),
  124. pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
  125. anc_points = anchors,
  126. gt_labels = tgt_labels,
  127. gt_bboxes = tgt_bboxes
  128. )
  129. gt_label_targets.append(gt_label)
  130. gt_score_targets.append(gt_score)
  131. gt_bbox_targets.append(gt_box)
  132. fg_masks.append(fg_mask)
  133. # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
  134. fg_masks = torch.cat(fg_masks, 0).view(-1) # [BM,]
  135. gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes) # [BM, C]
  136. gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4) # [BM, 4]
  137. gt_label_targets = torch.cat(gt_label_targets, 0).view(-1) # [BM,]
  138. gt_label_targets = torch.where(fg_masks > 0, gt_label_targets, torch.full_like(gt_label_targets, self.num_classes))
  139. gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
  140. bbox_weight = gt_score_targets[fg_masks].sum(-1)
  141. num_fgs = max(gt_score_targets.sum(), 1)
  142. # average loss normalizer across all the GPUs
  143. if is_dist_avail_and_initialized():
  144. torch.distributed.all_reduce(num_fgs)
  145. num_fgs = max(num_fgs / get_world_size(), 1.0)
  146. # update loss normalizer with EMA
  147. if self.use_ema_update:
  148. normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
  149. else:
  150. normalizer = num_fgs
  151. # ------------------ Classification loss ------------------
  152. cls_preds = cls_preds.view(-1, self.num_classes)
  153. loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_labels_one_hot, vfl=False)
  154. loss_cls = loss_cls.sum() / normalizer
  155. # ------------------ Regression loss ------------------
  156. box_preds_pos = box_preds.view(-1, 4)[fg_masks]
  157. box_targets_pos = gt_bbox_targets[fg_masks]
  158. loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
  159. loss_box = loss_box.sum() / normalizer
  160. # ------------------ Distribution focal loss ------------------
  161. ## process anchors
  162. anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
  163. ## process stride tensors
  164. strides = torch.cat(outputs['stride_tensor'], dim=0)
  165. strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
  166. ## fg preds
  167. reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
  168. anchors_pos = anchors[fg_masks]
  169. strides_pos = strides[fg_masks]
  170. ## compute dfl
  171. loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
  172. loss_dfl = loss_dfl.sum() / normalizer
  173. # total loss
  174. losses = self.loss_cls_weight * loss_cls + \
  175. self.loss_box_weight * loss_box + \
  176. self.loss_dfl_weight * loss_dfl
  177. loss_dict = dict(
  178. loss_cls = loss_cls,
  179. loss_box = loss_box,
  180. loss_dfl = loss_dfl,
  181. losses = losses
  182. )
  183. return loss_dict
  184. # ----------------- Loss with SimOTA assigner -----------------
  185. def ota_loss(self, outputs, targets):
  186. """ Compute loss with SimOTA assigner """
  187. bs = outputs['pred_cls'][0].shape[0]
  188. device = outputs['pred_cls'][0].device
  189. fpn_strides = outputs['strides']
  190. anchors = outputs['anchors']
  191. num_anchors = sum([ab.shape[0] for ab in anchors])
  192. # preds: [B, M, C]
  193. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  194. reg_preds = torch.cat(outputs['pred_reg'], dim=1)
  195. box_preds = torch.cat(outputs['pred_box'], dim=1)
  196. # --------------- label assignment ---------------
  197. cls_targets = []
  198. box_targets = []
  199. fg_masks = []
  200. for batch_idx in range(bs):
  201. tgt_labels = targets[batch_idx]["labels"].to(device)
  202. tgt_bboxes = targets[batch_idx]["boxes"].to(device)
  203. # check target
  204. if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
  205. # There is no valid gt
  206. cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
  207. box_target = cls_preds.new_zeros((0, 4))
  208. fg_mask = cls_preds.new_zeros(num_anchors).bool()
  209. else:
  210. (
  211. fg_mask,
  212. assigned_labels,
  213. assigned_ious,
  214. assigned_indexs
  215. ) = self.ota_matcher(
  216. fpn_strides = fpn_strides,
  217. anchors = anchors,
  218. pred_cls = cls_preds[batch_idx],
  219. pred_box = box_preds[batch_idx],
  220. tgt_labels = tgt_labels,
  221. tgt_bboxes = tgt_bboxes
  222. )
  223. # prepare cls targets
  224. assigned_labels = F.one_hot(assigned_labels.long(), self.num_classes)
  225. assigned_labels = assigned_labels * assigned_ious.unsqueeze(-1)
  226. cls_target = assigned_labels.new_zeros((num_anchors, self.num_classes))
  227. cls_target[fg_mask] = assigned_labels
  228. # prepare box targets
  229. box_target = tgt_bboxes[assigned_indexs]
  230. cls_targets.append(cls_target)
  231. box_targets.append(box_target)
  232. fg_masks.append(fg_mask)
  233. cls_targets = torch.cat(cls_targets, 0)
  234. box_targets = torch.cat(box_targets, 0)
  235. fg_masks = torch.cat(fg_masks, 0)
  236. num_fgs = fg_masks.sum()
  237. # average loss normalizer across all the GPUs
  238. if is_dist_avail_and_initialized():
  239. torch.distributed.all_reduce(num_fgs)
  240. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  241. # update loss normalizer with EMA
  242. if self.use_ema_update:
  243. normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
  244. else:
  245. normalizer = num_fgs
  246. # ------------------ Classification loss ------------------
  247. cls_preds = cls_preds.view(-1, self.num_classes)
  248. loss_cls = self.loss_classes(cls_preds, cls_targets)
  249. loss_cls = loss_cls.sum() / normalizer
  250. # ------------------ Regression loss ------------------
  251. box_preds_pos = box_preds.view(-1, 4)[fg_masks]
  252. loss_box = self.loss_bboxes(box_preds_pos, box_targets)
  253. loss_box = loss_box.sum() / normalizer
  254. # ------------------ Distribution focal loss ------------------
  255. ## process anchors
  256. anchors = torch.cat(anchors, dim=0)
  257. anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
  258. ## process stride tensors
  259. strides = torch.cat(outputs['stride_tensor'], dim=0)
  260. strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
  261. ## fg preds
  262. reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
  263. anchors_pos = anchors[fg_masks]
  264. strides_pos = strides[fg_masks]
  265. ## compute dfl
  266. loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
  267. loss_dfl = loss_dfl.sum() / normalizer
  268. # total loss
  269. losses = self.loss_cls_weight * loss_cls + \
  270. self.loss_box_weight * loss_box + \
  271. self.loss_dfl_weight * loss_dfl
  272. loss_dict = dict(
  273. loss_cls = loss_cls,
  274. loss_box = loss_box,
  275. loss_dfl = loss_dfl,
  276. losses = losses
  277. )
  278. return loss_dict
  279. def build_criterion(args, cfg, device, num_classes):
  280. criterion = Criterion(
  281. args=args,
  282. cfg=cfg,
  283. device=device,
  284. num_classes=num_classes
  285. )
  286. return criterion
  287. if __name__ == "__main__":
  288. pass