# ---------------------------------------------------------------------
# Copyright (c) Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.box_ops import *
from utils.misc import sigmoid_focal_loss
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized

from .matcher import UniformMatcher


class SetCriterion(nn.Module):
    """
    Loss criterion (classification + box regression) for YOLOF-style detection.
    Referenced from https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py
    """

    def __init__(self, cfg):
        super().__init__()
        # ------------- Basic parameters -------------
        self.cfg = cfg
        self.num_classes = cfg.num_classes
        # ------------- Focal loss -------------
        self.alpha = cfg.focal_loss_alpha
        self.gamma = cfg.focal_loss_gamma
        # ------------- Loss weight -------------
        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
                            'loss_reg': cfg.loss_reg_weight}
        # ------------- Matcher -------------
        self.matcher_cfg = cfg.matcher_hpy
        self.matcher = UniformMatcher(self.matcher_cfg['topk_candidates'])

    def loss_labels(self, pred_cls, tgt_cls, num_boxes):
        """
        pred_cls: (Tensor) [N, C]
        tgt_cls:  (Tensor) [N, C]
        """
        # classification loss (sigmoid focal loss), averaged over foreground boxes
        loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
        return loss_cls.sum() / num_boxes

    def loss_bboxes(self, pred_box, tgt_box, num_boxes):
        """
        pred_box: (Tensor) [N, 4]
        tgt_box:  (Tensor) [N, 4]
        """
        # pairwise GIoU; the diagonal holds the GIoU of each matched pair
        pred_giou = generalized_box_iou(pred_box, tgt_box)  # [N, N]
        # GIoU loss, averaged over foreground boxes
        loss_reg = 1. - torch.diag(pred_giou)
        return loss_reg.sum() / num_boxes

    def forward(self, outputs, targets):
        """
        outputs['pred_cls']: (Tensor) [B, M, C]
        outputs['pred_box']: (Tensor) [B, M, 4]
        targets: (List) [dict{'boxes': [...],
                              'labels': [...],
                              'orig_size': ...}, ...]
        """
        # -------------------- Pre-process --------------------
        pred_box = outputs['pred_box']
        pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes)
        anchor_boxes = outputs['anchors']
        # invert the padding mask: True marks positions that are kept
        masks = ~outputs['mask']
        device = pred_box.device
        B = len(targets)

        # -------------------- Label assignment --------------------
        indices = self.matcher(pred_box, anchor_boxes, targets)

        # [M, 4] -> [1, M, 4] -> [B, M, 4]
        anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes)
        anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)

        ious = []
        pos_ious = []
        for i in range(B):
            src_idx, tgt_idx = indices[i]
            # IoU between predicted boxes and target boxes
            iou, _ = box_iou(pred_box[i, ...], targets[i]['boxes'].clone())
            if iou.numel() == 0:
                max_iou = iou.new_full((iou.size(0),), 0)
            else:
                max_iou = iou.max(dim=1)[0]
            # IoU between anchor boxes and target boxes
            a_iou, _ = box_iou(anchor_boxes[i], targets[i]['boxes'].clone())
            if a_iou.numel() == 0:
                pos_iou = a_iou.new_full((0,), 0)
            else:
                pos_iou = a_iou[src_idx, tgt_idx]
            ious.append(max_iou)
            pos_ious.append(pos_iou)

        # predictions overlapping any GT above ignore_thresh are ignored in the cls loss
        ious = torch.cat(ious)
        ignore_idx = ious > self.matcher_cfg['ignore_thresh']
        # matched anchors whose IoU with their GT falls below iou_thresh are dropped as positives
        pos_ious = torch.cat(pos_ious)
        pos_ignore_idx = pos_ious < self.matcher_cfg['iou_thresh']

        # flatten the per-image anchor indices into indices over [B*M]
        src_idx = torch.cat(
            [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in
             enumerate(indices)])

        # class targets over all positions: [B*M,], initialized to background (= num_classes)
        gt_cls = torch.full(pred_cls.shape[:1],
                            self.num_classes,
                            dtype=torch.int64,
                            device=device)
        gt_cls[ignore_idx] = -1
        # labels of the matched anchors; low-quality matches are also ignored (-1)
        tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
        tgt_cls_o[pos_ignore_idx] = -1
        gt_cls[src_idx] = tgt_cls_o.to(device)

        foreground_idxs = (gt_cls >= 0) & (gt_cls != self.num_classes)
        num_foreground = foreground_idxs.sum()

        # average the number of foreground boxes across processes
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_foreground)
        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()

        # -------------------- Classification loss --------------------
        # one-hot targets; ignored (gt_cls == -1) and padded positions are excluded
        gt_cls_target = torch.zeros_like(pred_cls)
        gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1
        valid_idxs = (gt_cls >= 0) & masks
        loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_cls_target[valid_idxs], num_foreground)

        # -------------------- Regression loss --------------------
        # keep only the matched pairs that were not filtered out by iou_thresh
        tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device)
        tgt_boxes = tgt_boxes[~pos_ignore_idx]
        matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
        loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground)

        loss_dict = dict(
            loss_cls=loss_labels,
            loss_reg=loss_bboxes,
        )

        return loss_dict


if __name__ == "__main__":
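    # Minimal smoke-test sketch, not part of the original file: it only
    # illustrates the cfg fields read in __init__ and the tensor shapes that
    # forward() documents. Everything below is an assumption made for
    # illustration; the real cfg, anchors, and model outputs are built
    # elsewhere in the project. In particular, the padding mask is assumed
    # to be a flattened [B*M] bool tensor (True = padded), since forward()
    # combines it with [B*M]-shaped index masks.
    from types import SimpleNamespace

    num_classes = 20
    B, M = 2, 100  # batch size and number of anchor positions (made up)
    cfg = SimpleNamespace(
        num_classes=num_classes,
        focal_loss_alpha=0.25,
        focal_loss_gamma=2.0,
        loss_cls_weight=1.0,
        loss_reg_weight=1.0,
        matcher_hpy={'topk_candidates': 4,
                     'ignore_thresh': 0.7,
                     'iou_thresh': 0.15},
    )
    criterion = SetCriterion(cfg)

    outputs = {
        'pred_cls': torch.randn(B, M, num_classes),
        'pred_box': torch.rand(B, M, 4).sort(dim=-1)[0],  # crude but valid xyxy boxes
        'anchors': torch.rand(M, 4),                      # [cx, cy, w, h]
        'mask': torch.zeros(B * M, dtype=torch.bool),     # no padded positions
    }
    targets = [{'boxes': torch.tensor([[0.1, 0.1, 0.5, 0.5]]),
                'labels': torch.tensor([3]),
                'orig_size': (608, 608)} for _ in range(B)]

    loss_dict = criterion(outputs, targets)
    # weight_dict is defined in __init__ but applied by the caller, e.g.:
    total_loss = sum(loss_dict[k] * criterion.weight_dict[k] for k in loss_dict)
    print({k: v.item() for k, v in loss_dict.items()}, total_loss.item())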