criterion.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from utils.box_ops import *
  5. from utils.misc import sigmoid_focal_loss
  6. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  7. from .matcher import UniformMatcher
  8. class SetCriterion(nn.Module):
  9. """
  10. This code referenced to https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py
  11. """
  12. def __init__(self, cfg):
  13. super().__init__()
  14. # ------------- Basic parameters -------------
  15. self.cfg = cfg
  16. self.num_classes = cfg.num_classes
  17. # ------------- Focal loss -------------
  18. self.alpha = cfg.focal_loss_alpha
  19. self.gamma = cfg.focal_loss_gamma
  20. # ------------- Loss weight -------------
  21. self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
  22. 'loss_reg': cfg.loss_reg_weight}
  23. # ------------- Matcher -------------
  24. self.matcher_cfg = cfg.matcher_hpy
  25. self.matcher = UniformMatcher(self.matcher_cfg['topk_candidates'])
  26. def loss_labels(self, pred_cls, tgt_cls, num_boxes):
  27. """
  28. pred_cls: (Tensor) [N, C]
  29. tgt_cls: (Tensor) [N, C]
  30. """
  31. # cls loss: [V, C]
  32. loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
  33. return loss_cls.sum() / num_boxes
  34. def loss_bboxes(self, pred_box, tgt_box, num_boxes):
  35. """
  36. pred_box: (Tensor) [N, 4]
  37. tgt_box: (Tensor) [N, 4]
  38. """
  39. # giou
  40. pred_giou = generalized_box_iou(pred_box, tgt_box) # [N, M]
  41. # giou loss
  42. loss_reg = 1. - torch.diag(pred_giou)
  43. return loss_reg.sum() / num_boxes
  44. def forward(self, outputs, targets):
  45. """
  46. outputs['pred_cls']: (Tensor) [B, M, C]
  47. outputs['pred_box']: (Tensor) [B, M, 4]
  48. targets: (List) [dict{'boxes': [...],
  49. 'labels': [...],
  50. 'orig_size': ...}, ...]
  51. """
  52. # -------------------- Pre-process --------------------
  53. pred_box = outputs['pred_box']
  54. pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes)
  55. anchor_boxes = outputs['anchors']
  56. masks = ~outputs['mask']
  57. device = pred_box.device
  58. B = len(targets)
  59. # -------------------- Label assignment --------------------
  60. indices = self.matcher(pred_box, anchor_boxes, targets)
  61. # [M, 4] -> [1, M, 4] -> [B, M, 4]
  62. anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes)
  63. anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)
  64. ious = []
  65. pos_ious = []
  66. for i in range(B):
  67. src_idx, tgt_idx = indices[i]
  68. # iou between predbox and tgt box
  69. iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone())
  70. if iou.numel() == 0:
  71. max_iou = iou.new_full((iou.size(0),), 0)
  72. else:
  73. max_iou = iou.max(dim=1)[0]
  74. # iou between anchorbox and tgt box
  75. a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone())
  76. if a_iou.numel() == 0:
  77. pos_iou = a_iou.new_full((0,), 0)
  78. else:
  79. pos_iou = a_iou[src_idx, tgt_idx]
  80. ious.append(max_iou)
  81. pos_ious.append(pos_iou)
  82. ious = torch.cat(ious)
  83. ignore_idx = ious > self.matcher_cfg['ignore_thresh']
  84. pos_ious = torch.cat(pos_ious)
  85. pos_ignore_idx = pos_ious < self.matcher_cfg['iou_thresh']
  86. src_idx = torch.cat(
  87. [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in
  88. enumerate(indices)])
  89. # [BM,]
  90. gt_cls = torch.full(pred_cls.shape[:1],
  91. self.num_classes,
  92. dtype=torch.int64,
  93. device=device)
  94. gt_cls[ignore_idx] = -1
  95. tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
  96. tgt_cls_o[pos_ignore_idx] = -1
  97. gt_cls[src_idx] = tgt_cls_o.to(device)
  98. foreground_idxs = (gt_cls >= 0) & (gt_cls != self.num_classes)
  99. num_foreground = foreground_idxs.sum()
  100. if is_dist_avail_and_initialized():
  101. torch.distributed.all_reduce(num_foreground)
  102. num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
  103. # -------------------- Classification loss --------------------
  104. gt_cls_target = torch.zeros_like(pred_cls)
  105. gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1
  106. valid_idxs = (gt_cls >= 0) & masks
  107. loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_cls_target[valid_idxs], num_foreground)
  108. # -------------------- Regression loss --------------------
  109. tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device)
  110. tgt_boxes = tgt_boxes[~pos_ignore_idx]
  111. matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
  112. loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground)
  113. total_loss = loss_labels * self.weight_dict["loss_cls"] + \
  114. loss_bboxes * self.weight_dict["loss_reg"]
  115. loss_dict = dict(
  116. loss_cls = loss_labels,
  117. loss_reg = loss_bboxes,
  118. losses = total_loss,
  119. )
  120. return loss_dict
  121. if __name__ == "__main__":
  122. pass