loss.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from utils.box_ops import *
  5. from utils.misc import sigmoid_focal_loss
  6. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  7. from .matcher import UniformMatcher
  8. class SetCriterion(object):
  9. """
  10. This code referenced to https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py
  11. """
  12. def __init__(self, cfg):
  13. # ------------- Basic parameters -------------
  14. self.cfg = cfg
  15. self.num_classes = cfg.num_classes
  16. # ------------- Focal loss -------------
  17. self.alpha = cfg.focal_loss_alpha
  18. self.gamma = cfg.focal_loss_gamma
  19. # ------------- Loss weight -------------
  20. self.weight_dict = {'loss_cls': cfg.loss_cls,
  21. 'loss_reg': cfg.loss_reg}
  22. # ------------- Matcher -------------
  23. self.ignore_thresh = cfg.ignore_thresh
  24. self.match_iou_weight = cfg.match_iou_thresh
  25. self.matcher = UniformMatcher(cfg.match_topk_candidates)
  26. def loss_labels(self, pred_cls, tgt_cls, num_boxes):
  27. """
  28. pred_cls: (Tensor) [N, C]
  29. tgt_cls: (Tensor) [N, C]
  30. """
  31. # cls loss: [V, C]
  32. loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
  33. return loss_cls.sum() / num_boxes
  34. def loss_bboxes(self, pred_box, tgt_box, num_boxes):
  35. """
  36. pred_box: (Tensor) [N, 4]
  37. tgt_box: (Tensor) [N, 4]
  38. """
  39. # giou
  40. pred_giou = generalized_box_iou(pred_box, tgt_box) # [N, M]
  41. # giou loss
  42. loss_reg = 1. - torch.diag(pred_giou)
  43. return loss_reg.sum() / num_boxes
  44. def __call__(self, outputs, targets):
  45. """
  46. outputs['pred_cls']: (Tensor) [B, M, C]
  47. outputs['pred_box']: (Tensor) [B, M, 4]
  48. targets: (List) [dict{'boxes': [...],
  49. 'labels': [...],
  50. 'orig_size': ...}, ...]
  51. """
  52. # -------------------- Pre-process --------------------
  53. pred_box = outputs['pred_box']
  54. pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes)
  55. anchor_boxes = outputs['anchors']
  56. device = pred_box.device
  57. bs = len(targets)
  58. # -------------------- Label assignment --------------------
  59. indices = self.matcher(pred_box, anchor_boxes, targets)
  60. # [M, 4] -> [1, M, 4] -> [B, M, 4]
  61. anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes)
  62. anchor_boxes = anchor_boxes[None].repeat(bs, 1, 1)
  63. ious = []
  64. pos_ious = []
  65. for i in range(bs):
  66. src_idx, tgt_idx = indices[i]
  67. # iou between predbox and tgt box
  68. iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone().to(device))
  69. if iou.numel() == 0:
  70. max_iou = iou.new_full((iou.size(0),), 0)
  71. else:
  72. max_iou = iou.max(dim=1)[0]
  73. # iou between anchorbox and tgt box
  74. a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone().to(device))
  75. if a_iou.numel() == 0:
  76. pos_iou = a_iou.new_full((0,), 0)
  77. else:
  78. pos_iou = a_iou[src_idx, tgt_idx]
  79. ious.append(max_iou)
  80. pos_ious.append(pos_iou)
  81. ious = torch.cat(ious)
  82. ignore_idx = ious > self.ignore_thresh
  83. pos_ious = torch.cat(pos_ious)
  84. pos_ignore_idx = pos_ious < self.match_iou_weight
  85. src_idx = torch.cat(
  86. [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in
  87. enumerate(indices)])
  88. # [BM,]
  89. gt_cls = torch.full(pred_cls.shape[:1],
  90. self.num_classes,
  91. dtype=torch.int64,
  92. device=device)
  93. gt_cls[ignore_idx] = -1
  94. tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
  95. tgt_cls_o[pos_ignore_idx] = -1
  96. gt_cls[src_idx] = tgt_cls_o.to(device)
  97. fg_mask = (gt_cls >= 0) & (gt_cls != self.num_classes)
  98. num_fgs = fg_mask.sum()
  99. if is_dist_avail_and_initialized():
  100. torch.distributed.all_reduce(num_fgs)
  101. num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
  102. # -------------------- Classification loss --------------------
  103. gt_cls_target = torch.zeros_like(pred_cls)
  104. gt_cls_target[fg_mask, gt_cls[fg_mask]] = 1
  105. loss_labels = self.loss_labels(pred_cls, gt_cls_target, num_fgs)
  106. # -------------------- Regression loss --------------------
  107. tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device)
  108. tgt_boxes = tgt_boxes[~pos_ignore_idx]
  109. matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
  110. loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_fgs)
  111. total_loss = loss_labels * self.weight_dict["loss_cls"] + \
  112. loss_bboxes * self.weight_dict["loss_reg"]
  113. loss_dict = dict(
  114. loss_cls = loss_labels,
  115. loss_reg = loss_bboxes,
  116. losses = total_loss,
  117. )
  118. return loss_dict
# Script entry point: intentionally a no-op, this module has no standalone behaviour.
if __name__ == "__main__":
    pass