loss.py

import torch
import torch.nn.functional as F

from utils.box_ops import bbox2dist, get_ious
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized

from .matcher import TaskAlignedAssigner, AlignedSimOTA

class Criterion(object):
    def __init__(self, args, cfg, device, num_classes=80):
        self.cfg = cfg
        self.args = args
        self.device = device
        self.num_classes = num_classes
        self.use_ema_update = cfg['ema_update']
        # ---------------- Loss weight ----------------
        self.loss_cls_weight = cfg['loss_cls_weight']
        self.loss_box_weight = cfg['loss_box_weight']
        self.loss_dfl_weight = cfg['loss_dfl_weight']
        # ---------------- Matcher ----------------
        matcher_config = cfg['matcher']
        self.switch_epoch = matcher_config['switch_epoch']
        ## TAL assigner
        self.tal_matcher = TaskAlignedAssigner(
            topk=matcher_config['tal']['topk'],
            alpha=matcher_config['tal']['alpha'],
            beta=matcher_config['tal']['beta'],
            num_classes=num_classes
        )
        ## SimOTA assigner
        self.ota_matcher = AlignedSimOTA(
            center_sampling_radius=matcher_config['ota']['center_sampling_radius'],
            topk_candidate=matcher_config['ota']['topk_candidate'],
            num_classes=num_classes
        )

    def __call__(self, outputs, targets, epoch=0):
        # Use the SimOTA assigner for the early epochs, then switch to the TAL assigner.
        if epoch < self.switch_epoch:
            return self.ota_loss(outputs, targets)
        else:
            return self.tal_loss(outputs, targets)

    def ema_update(self, name: str, value, initial_value, momentum=0.9):
        # Exponential moving average of a scalar attribute (used for the loss normalizer).
        if hasattr(self, name):
            old = getattr(self, name)
        else:
            old = initial_value
        new = old * momentum + value * (1 - momentum)
        setattr(self, name, new)

        return new

    # ----------------- Loss functions -----------------
    def loss_classes(self, pred_cls, gt_score, gt_label=None, vfl=False):
        if vfl:
            assert gt_label is not None
            # compute varifocal loss
            alpha, gamma = 0.75, 2.0
            focal_weight = alpha * pred_cls.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
            bce_loss = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
            loss_cls = bce_loss * focal_weight
        else:
            # compute bce loss
            loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')

        return loss_cls

    def loss_bboxes(self, pred_box, gt_box, bbox_weight=None):
        # regression loss
        ious = get_ious(pred_box, gt_box, 'xyxy', 'giou')
        loss_box = 1.0 - ious
        if bbox_weight is not None:
            loss_box *= bbox_weight

        return loss_box

    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
        # rescale coords by stride
        gt_box_s = gt_box / stride
        anchor_s = anchor / stride

        # compute deltas
        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.cfg['reg_max'] - 1)

        gt_left = gt_ltrb_s.to(torch.long)
        gt_right = gt_left + 1

        weight_left = gt_right.to(torch.float) - gt_ltrb_s
        weight_right = 1 - weight_left

        # loss left
        loss_left = F.cross_entropy(
            pred_reg.view(-1, self.cfg['reg_max']),
            gt_left.view(-1),
            reduction='none').view(gt_left.shape) * weight_left

        # loss right
        loss_right = F.cross_entropy(
            pred_reg.view(-1, self.cfg['reg_max']),
            gt_right.view(-1),
            reduction='none').view(gt_left.shape) * weight_right

        loss_dfl = (loss_left + loss_right).mean(-1)

        if bbox_weight is not None:
            loss_dfl *= bbox_weight

        return loss_dfl
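
    # Note on the DFL targets above: each ground-truth side distance (in stride
    # units, clipped to reg_max - 1) is split between its two neighbouring
    # integer bins. For example, a distance of 2.3 yields gt_left = 2 and
    # gt_right = 3 with weight_left = 3 - 2.3 = 0.7 and weight_right = 0.3, so
    # the two cross-entropy terms interpolate the continuous offset between
    # discrete bins.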

    # ----------------- Loss with TAL assigner -----------------
    def tal_loss(self, outputs, targets):
        """ Compute loss with TAL assigner """
        bs = outputs['pred_cls'][0].shape[0]
        device = outputs['pred_cls'][0].device
        anchors = torch.cat(outputs['anchors'], dim=0)
        num_anchors = anchors.shape[0]

        # preds: [B, M, C]
        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
        box_preds = torch.cat(outputs['pred_box'], dim=1)

        # --------------- label assignment ---------------
        gt_label_targets = []
        gt_score_targets = []
        gt_bbox_targets = []
        fg_masks = []
        for batch_idx in range(bs):
            tgt_labels = targets[batch_idx]["labels"].to(device)
            tgt_bboxes = targets[batch_idx]["boxes"].to(device)

            # check target
            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
                # There is no valid gt
                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()                # [1, M,]
                gt_label = cls_preds.new_zeros((1, num_anchors,))                   # [1, M,]
                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes))  # [1, M, C]
                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                   # [1, M, 4]
            else:
                tgt_labels = tgt_labels[None, :, None]  # [1, Mp, 1]
                tgt_bboxes = tgt_bboxes[None]           # [1, Mp, 4]
                (
                    gt_label,  # [1, M,]
                    gt_box,    # [1, M, 4]
                    gt_score,  # [1, M, C]
                    fg_mask,   # [1, M,]
                    _
                ) = self.tal_matcher(
                    pd_scores=cls_preds[batch_idx:batch_idx+1].detach().sigmoid(),
                    pd_bboxes=box_preds[batch_idx:batch_idx+1].detach(),
                    anc_points=anchors,
                    gt_labels=tgt_labels,
                    gt_bboxes=tgt_bboxes
                )
            gt_label_targets.append(gt_label)
            gt_score_targets.append(gt_score)
            gt_bbox_targets.append(gt_box)
            fg_masks.append(fg_mask)

        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
        gt_label_targets = torch.where(fg_masks > 0, gt_label_targets, torch.full_like(gt_label_targets, self.num_classes))
        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
        bbox_weight = gt_score_targets[fg_masks].sum(-1)

        # keep the normalizer a tensor and clamp it to at least 1 so that
        # all_reduce still works when a batch has no positive samples
        num_fgs = gt_score_targets.sum().clamp(1.0)

        # average loss normalizer across all the GPUs
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_fgs)
        num_fgs = (num_fgs / get_world_size()).clamp(1.0)

        # update loss normalizer with EMA
        if self.use_ema_update:
            normalizer = self.ema_update("loss_normalizer", num_fgs, 100)
        else:
            normalizer = num_fgs

        # ------------------ Classification loss ------------------
        cls_preds = cls_preds.view(-1, self.num_classes)
        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_labels_one_hot, vfl=False)
        loss_cls = loss_cls.sum() / normalizer

        # ------------------ Regression loss ------------------
        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
        box_targets_pos = gt_bbox_targets[fg_masks]
        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
        loss_box = loss_box.sum() / normalizer

        # ------------------ Distribution focal loss ------------------
        ## process anchors
        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
        ## process stride tensors
        strides = torch.cat(outputs['stride_tensor'], dim=0)
        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
        ## fg preds
        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
        anchors_pos = anchors[fg_masks]
        strides_pos = strides[fg_masks]
        ## compute dfl
        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
        loss_dfl = loss_dfl.sum() / normalizer

        # total loss
        losses = self.loss_cls_weight * loss_cls + \
                 self.loss_box_weight * loss_box + \
                 self.loss_dfl_weight * loss_dfl

        loss_dict = dict(
            loss_cls=loss_cls,
            loss_box=loss_box,
            loss_dfl=loss_dfl,
            losses=losses
        )

        return loss_dict

    # ----------------- Loss with SimOTA assigner -----------------
    def ota_loss(self, outputs, targets):
        """ Compute loss with SimOTA assigner """
        bs = outputs['pred_cls'][0].shape[0]
        device = outputs['pred_cls'][0].device
        fpn_strides = outputs['strides']
        anchors = outputs['anchors']
        num_anchors = sum([ab.shape[0] for ab in anchors])

        # preds: [B, M, C]
        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
        box_preds = torch.cat(outputs['pred_box'], dim=1)

        # --------------- label assignment ---------------
        cls_targets = []
        box_targets = []
        fg_masks = []
        for batch_idx in range(bs):
            tgt_labels = targets[batch_idx]["labels"].to(device)
            tgt_bboxes = targets[batch_idx]["boxes"].to(device)

            # check target
            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
                # There is no valid gt
                cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
                box_target = cls_preds.new_zeros((0, 4))
                fg_mask = cls_preds.new_zeros(num_anchors).bool()
            else:
                (
                    fg_mask,
                    assigned_labels,
                    assigned_ious,
                    assigned_indexs
                ) = self.ota_matcher(
                    fpn_strides=fpn_strides,
                    anchors=anchors,
                    pred_cls=cls_preds[batch_idx],
                    pred_box=box_preds[batch_idx],
                    tgt_labels=tgt_labels,
                    tgt_bboxes=tgt_bboxes
                )
                # prepare cls targets
                assigned_labels = F.one_hot(assigned_labels.long(), self.num_classes)
                assigned_labels = assigned_labels * assigned_ious.unsqueeze(-1)
                cls_target = assigned_labels.new_zeros((num_anchors, self.num_classes))
                cls_target[fg_mask] = assigned_labels
                # prepare box targets
                box_target = tgt_bboxes[assigned_indexs]

            cls_targets.append(cls_target)
            box_targets.append(box_target)
            fg_masks.append(fg_mask)

        cls_targets = torch.cat(cls_targets, 0)
        box_targets = torch.cat(box_targets, 0)
        fg_masks = torch.cat(fg_masks, 0)
        num_fgs = fg_masks.sum()

        # average loss normalizer across all the GPUs
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_fgs)
        num_fgs = (num_fgs / get_world_size()).clamp(1.0)

        # update loss normalizer with EMA
        if self.use_ema_update:
            normalizer = self.ema_update("loss_normalizer", num_fgs, 100)
        else:
            normalizer = num_fgs

        # ------------------ Classification loss ------------------
        cls_preds = cls_preds.view(-1, self.num_classes)
        loss_cls = self.loss_classes(cls_preds, cls_targets)
        loss_cls = loss_cls.sum() / normalizer

        # ------------------ Regression loss ------------------
        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
        loss_box = loss_box.sum() / normalizer

        # ------------------ Distribution focal loss ------------------
        ## process anchors
        anchors = torch.cat(anchors, dim=0)
        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
        ## process stride tensors
        strides = torch.cat(outputs['stride_tensor'], dim=0)
        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
        ## fg preds
        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
        anchors_pos = anchors[fg_masks]
        strides_pos = strides[fg_masks]
        ## compute dfl
        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
        loss_dfl = loss_dfl.sum() / normalizer

        # total loss
        losses = self.loss_cls_weight * loss_cls + \
                 self.loss_box_weight * loss_box + \
                 self.loss_dfl_weight * loss_dfl

        loss_dict = dict(
            loss_cls=loss_cls,
            loss_box=loss_box,
            loss_dfl=loss_dfl,
            losses=losses
        )

        return loss_dict


def build_criterion(args, cfg, device, num_classes):
    criterion = Criterion(
        args=args,
        cfg=cfg,
        device=device,
        num_classes=num_classes
    )

    return criterion


if __name__ == "__main__":
    pass
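
    # ------------------------------------------------------------------
    # Minimal usage sketch (illustrative only). The config keys below
    # mirror the ones read in __init__ and the loss functions; the
    # concrete values (loss weights, topk, alpha/beta, sampling radius,
    # reg_max, switch_epoch) are placeholder assumptions, not values
    # taken from this repository's configs.
    # ------------------------------------------------------------------
    cfg = {
        'ema_update': False,
        'loss_cls_weight': 1.0,   # assumed placeholder weights
        'loss_box_weight': 2.0,
        'loss_dfl_weight': 0.5,
        'reg_max': 16,
        'matcher': {
            'switch_epoch': 1,
            'tal': {'topk': 10, 'alpha': 0.5, 'beta': 6.0},
            'ota': {'center_sampling_radius': 2.5, 'topk_candidate': 10},
        },
    }
    # `targets` is a list with one dict per image, as consumed by
    # tal_loss / ota_loss:
    targets = [
        {'labels': torch.tensor([1, 3]),
         'boxes': torch.tensor([[10., 10., 50., 80.], [30., 40., 90., 120.]])},
    ]
    # `outputs` is expected to hold per-level lists:
    #   'pred_cls':      list of [B, Mi, num_classes] classification logits
    #   'pred_reg':      list of [B, Mi, 4 * reg_max] distribution logits
    #   'pred_box':      list of [B, Mi, 4] decoded boxes (xyxy)
    #   'anchors':       list of [Mi, 2] anchor points
    #   'stride_tensor': list of [Mi, 1] per-anchor strides
    #   'strides':       per-level FPN strides (passed to the SimOTA matcher)
    #
    # With a detection model and an `args` namespace available, the
    # criterion would be used roughly as follows:
    #   criterion = build_criterion(args, cfg, device, num_classes=80)
    #   loss_dict = criterion(outputs, targets, epoch=0)
    #   loss_dict['losses'].backward()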