criterion.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
  5. from utils.misc import sigmoid_focal_loss
  6. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  7. from .matcher import RetinaNetMatcher
  8. class Criterion(nn.Module):
  9. def __init__(self, cfg, num_classes=80):
  10. super().__init__()
  11. # ------------- Basic parameters -------------
  12. self.cfg = cfg
  13. self.num_classes = num_classes
  14. # ------------- Focal loss -------------
  15. self.alpha = cfg['focal_loss_alpha']
  16. self.gamma = cfg['focal_loss_gamma']
  17. # ------------- Loss weight -------------
  18. self.weight_dict = {'loss_cls': cfg['loss_cls_weight'],
  19. 'loss_reg': cfg['loss_reg_weight']}
  20. # ------------- Matcher -------------
  21. self.matcher_cfg = cfg['matcher_hpy']
  22. self.matcher = RetinaNetMatcher(num_classes,
  23. iou_threshold=self.matcher_cfg['iou_thresh'],
  24. iou_labels=self.matcher_cfg['iou_labels'],
  25. allow_low_quality_matches=self.matcher_cfg['allow_low_quality_matches']
  26. )
  27. def loss_labels(self, pred_cls, tgt_cls, num_boxes):
  28. """
  29. pred_cls: (Tensor) [N, C]
  30. tgt_cls: (Tensor) [N, C]
  31. """
  32. # cls loss: [V, C]
  33. loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
  34. return loss_cls.sum() / num_boxes
  35. def loss_bboxes(self, pred_reg=None, pred_box=None, tgt_box=None, anchors=None, num_boxes=1, use_giou=False):
  36. """
  37. pred_reg: (Tensor) [Nq, 4]
  38. tgt_box: (Tensor) [Nq, 4]
  39. anchors: (Tensor) [Nq, 4]
  40. """
  41. # GIoU loss
  42. if use_giou:
  43. pred_giou = generalized_box_iou(pred_box, tgt_box) # [N, M]
  44. loss_reg = 1. - torch.diag(pred_giou)
  45. # L1 loss
  46. else:
  47. # xyxy -> cxcy&bwbh
  48. tgt_cxcy = (tgt_box[..., :2] + tgt_box[..., 2:]) * 0.5
  49. tgt_bwbh = tgt_box[..., 2:] - tgt_box[..., :2]
  50. # encode gt box
  51. tgt_offsets = (tgt_cxcy - anchors[..., :2]) / anchors[..., 2:]
  52. tgt_sizes = torch.log(tgt_bwbh / anchors[..., 2:])
  53. tgt_box_encode = torch.cat([tgt_offsets, tgt_sizes], dim=-1)
  54. # compute l1 loss
  55. loss_reg = F.l1_loss(pred_reg, tgt_box_encode, reduction='none')
  56. return loss_reg.sum() / num_boxes
  57. def forward(self, outputs, targets):
  58. """
  59. outputs['pred_cls']: (Tensor) [B, M, C]
  60. outputs['pred_reg']: (Tensor) [B, M, 4]
  61. outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
  62. targets: (List) [dict{'boxes': [...],
  63. 'labels': [...],
  64. 'orig_size': ...}, ...]
  65. anchors: (Tensor) [M, 4]
  66. """
  67. # -------------------- Pre-process --------------------
  68. cls_preds = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
  69. reg_preds = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4)
  70. box_preds = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)
  71. masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
  72. B = len(targets)
  73. # process anchor boxes
  74. anchor_boxes = torch.cat(outputs['anchors'])
  75. anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)
  76. anchor_boxes_xyxy = box_cxcywh_to_xyxy(anchor_boxes)
  77. # -------------------- Label Assignment --------------------
  78. tgt_classes, tgt_boxes = self.matcher(anchor_boxes_xyxy, targets)
  79. tgt_classes = tgt_classes.flatten()
  80. tgt_boxes = tgt_boxes.view(-1, 4)
  81. del anchor_boxes_xyxy
  82. foreground_idxs = (tgt_classes >= 0) & (tgt_classes != self.num_classes)
  83. valid_idxs = (tgt_classes >= 0) & masks
  84. num_foreground = foreground_idxs.sum()
  85. if is_dist_avail_and_initialized():
  86. torch.distributed.all_reduce(num_foreground)
  87. num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
  88. # -------------------- Classification loss --------------------
  89. gt_cls_target = torch.zeros_like(cls_preds)
  90. gt_cls_target[foreground_idxs, tgt_classes[foreground_idxs]] = 1
  91. loss_labels = self.loss_labels(
  92. cls_preds[valid_idxs], gt_cls_target[valid_idxs], num_foreground)
  93. # -------------------- Regression loss --------------------
  94. if self.cfg['use_giou_loss']:
  95. box_preds_pos = box_preds[foreground_idxs]
  96. tgt_boxes_pos = tgt_boxes[foreground_idxs].to(reg_preds.device)
  97. loss_bboxes = self.loss_bboxes(
  98. pred_box=box_preds_pos, tgt_box=tgt_boxes_pos, num_boxes=num_foreground, use_giou=self.cfg['use_giou_loss'])
  99. else:
  100. reg_preds_pos = reg_preds[foreground_idxs]
  101. tgt_boxes_pos = tgt_boxes[foreground_idxs].to(reg_preds.device)
  102. anchors_pos = anchor_boxes.view(-1, 4)[foreground_idxs]
  103. loss_bboxes = self.loss_bboxes(
  104. pred_reg=reg_preds_pos, tgt_box=tgt_boxes_pos, anchors=anchors_pos, num_boxes=num_foreground, use_giou=self.cfg['use_giou_loss'])
  105. loss_dict = dict(
  106. loss_cls = loss_labels,
  107. loss_reg = loss_bboxes,
  108. )
  109. return loss_dict
  110. # build criterion
  111. def build_criterion(cfg, num_classes=80):
  112. criterion = Criterion(cfg=cfg, num_classes=num_classes)
  113. return criterion
if __name__ == "__main__":
    # Placeholder: running this module directly does nothing.
    pass