| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
- from utils.misc import sigmoid_focal_loss
- from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
- from .matcher import RetinaNetMatcher
class Criterion(nn.Module):
    """Training criterion: sigmoid focal classification loss plus box regression loss.

    Anchors are assigned to ground-truth boxes by ``RetinaNetMatcher``; the
    classification loss is computed with ``sigmoid_focal_loss`` and the
    regression loss is either a GIoU loss on decoded boxes or an L1 loss on
    anchor-relative deltas, selected by ``cfg['use_giou_loss']``.

    Configuration keys read from ``cfg``:
        focal_loss_alpha, focal_loss_gamma: forwarded to ``sigmoid_focal_loss``.
        loss_cls_weight, loss_reg_weight: stored in ``self.weight_dict``.
            ``forward`` returns the raw losses and does not apply these weights.
        matcher_hpy: dict with ``iou_thresh``, ``iou_labels`` and
            ``allow_low_quality_matches``, forwarded to ``RetinaNetMatcher``.
        use_giou_loss: read in ``forward`` to pick the regression loss.
    """

    def __init__(self, cfg, num_classes=80):
        """Store the loss hyper-parameters and build the anchor matcher.

        Args:
            cfg: Mapping holding the configuration keys listed on the class.
            num_classes: Number of object categories (width of the class logits).
        """
        super().__init__()
        # ------------- Basic parameters -------------
        self.cfg = cfg
        self.num_classes = num_classes
        # ------------- Focal loss -------------
        self.alpha = cfg['focal_loss_alpha']
        self.gamma = cfg['focal_loss_gamma']
        # ------------- Loss weight -------------
        self.weight_dict = {'loss_cls': cfg['loss_cls_weight'],
                            'loss_reg': cfg['loss_reg_weight']}
        # ------------- Matcher -------------
        self.matcher_cfg = cfg['matcher_hpy']
        self.matcher = RetinaNetMatcher(
            num_classes,
            iou_threshold=self.matcher_cfg['iou_thresh'],
            iou_labels=self.matcher_cfg['iou_labels'],
            allow_low_quality_matches=self.matcher_cfg['allow_low_quality_matches'],
        )

    def loss_labels(self, pred_cls, tgt_cls, num_boxes):
        """Return the summed focal loss divided by ``num_boxes``.

        Args:
            pred_cls: (Tensor) [N, C] class predictions.
            tgt_cls: (Tensor) [N, C] classification targets.
            num_boxes: Normalizer applied to the summed loss.
        """
        # presumably an element-wise loss of shape [N, C] -- confirm against
        # utils.misc.sigmoid_focal_loss
        loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
        return loss_cls.sum() / num_boxes

    def loss_bboxes(self, pred_reg=None, pred_box=None, tgt_box=None, anchors=None, num_boxes=1, use_giou=False):
        """Return the box regression loss, summed and divided by ``num_boxes``.

        Args:
            pred_reg: (Tensor) [Nq, 4] predicted deltas; required when
                ``use_giou`` is false.
            pred_box: (Tensor) [Nq, 4] predicted boxes; required when
                ``use_giou`` is true.
            tgt_box: (Tensor) [Nq, 4] target boxes in xyxy order; always required.
            anchors: (Tensor) [Nq, 4] anchors laid out as (cx, cy, w, h);
                required when ``use_giou`` is false.
            num_boxes: Normalizer applied to the summed loss.
            use_giou: Use the GIoU loss instead of the L1 loss on deltas.

        Raises:
            ValueError: If a tensor required by the selected mode is ``None``.
        """
        if tgt_box is None:
            raise ValueError("loss_bboxes requires `tgt_box`.")

        # GIoU loss
        if use_giou:
            if pred_box is None:
                raise ValueError("loss_bboxes with use_giou=True requires `pred_box`.")
            # The helper returns a pairwise matrix; only the diagonal
            # (i-th prediction vs. i-th target) is used.
            pred_giou = generalized_box_iou(pred_box, tgt_box)  # [N, M]
            loss_reg = 1. - torch.diag(pred_giou)

        # L1 loss
        else:
            if pred_reg is None or anchors is None:
                raise ValueError(
                    "loss_bboxes with use_giou=False requires `pred_reg` and `anchors`.")
            anchor_cxcy = anchors[..., :2]
            anchor_bwbh = anchors[..., 2:]
            # xyxy -> cxcy&bwbh
            tgt_cxcy = (tgt_box[..., :2] + tgt_box[..., 2:]) * 0.5
            tgt_bwbh = tgt_box[..., 2:] - tgt_box[..., :2]
            # encode gt box
            tgt_offsets = (tgt_cxcy - anchor_cxcy) / anchor_bwbh
            # A zero-size or inverted target box would give log(0) = -inf or
            # log(<0) = NaN and poison the loss; clamp the ratio so the
            # encoded size stays finite.
            size_ratio = tgt_bwbh / anchor_bwbh
            size_ratio = size_ratio.clamp(min=torch.finfo(size_ratio.dtype).eps)
            tgt_sizes = torch.log(size_ratio)
            tgt_box_encode = torch.cat([tgt_offsets, tgt_sizes], dim=-1)
            # compute l1 loss
            loss_reg = F.l1_loss(pred_reg, tgt_box_encode, reduction='none')

        return loss_reg.sum() / num_boxes

    def forward(self, outputs, targets):
        """Assign targets to anchors and compute the classification/regression losses.

        Args:
            outputs: Mapping with the following entries, each a sequence of
                tensors that is concatenated here (presumably one tensor per
                feature level -- confirm against the model head):
                'pred_cls': class predictions, concatenated on dim 1 and
                    flattened to [-1, num_classes].
                'pred_reg': box deltas, concatenated on dim 1 and flattened
                    to [-1, 4].
                'pred_box': decoded boxes, concatenated on dim 1 and
                    flattened to [-1, 4].
                'mask': boolean masks, concatenated on dim 1 and flattened;
                    positions where the mask is False take part in the
                    classification loss (presumably True marks padding).
                'anchors': anchor boxes, concatenated on dim 0; converted
                    with ``box_cxcywh_to_xyxy`` before matching.
            targets: (List) one entry per image; its length is used as the
                batch size and it is passed unchanged to the matcher.

        Returns:
            dict with the unweighted losses 'loss_cls' and 'loss_reg'.
        """
        # -------------------- Pre-process --------------------
        cls_preds = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
        reg_preds = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4)
        box_preds = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)
        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
        B = len(targets)

        # process anchor boxes: share the same anchors across the batch
        anchor_boxes = torch.cat(outputs['anchors'])
        anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)
        anchor_boxes_xyxy = box_cxcywh_to_xyxy(anchor_boxes)

        # -------------------- Label Assignment --------------------
        tgt_classes, tgt_boxes = self.matcher(anchor_boxes_xyxy, targets)
        tgt_classes = tgt_classes.flatten()
        tgt_boxes = tgt_boxes.view(-1, 4)
        del anchor_boxes_xyxy

        # Foreground: non-negative label other than `num_classes`
        # (presumably negative = ignored, num_classes = background).
        foreground_idxs = (tgt_classes >= 0) & (tgt_classes != self.num_classes)
        valid_idxs = (tgt_classes >= 0) & masks

        # Normalizer: foreground count averaged over processes, at least 1.
        num_foreground = foreground_idxs.sum()
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_foreground)
        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()

        # -------------------- Classification loss --------------------
        gt_cls_target = torch.zeros_like(cls_preds)
        gt_cls_target[foreground_idxs, tgt_classes[foreground_idxs]] = 1
        loss_labels = self.loss_labels(
            cls_preds[valid_idxs], gt_cls_target[valid_idxs], num_foreground)

        # -------------------- Regression loss --------------------
        use_giou = self.cfg['use_giou_loss']
        tgt_boxes_pos = tgt_boxes[foreground_idxs].to(reg_preds.device)
        if use_giou:
            loss_bboxes = self.loss_bboxes(
                pred_box=box_preds[foreground_idxs],
                tgt_box=tgt_boxes_pos,
                num_boxes=num_foreground,
                use_giou=use_giou)
        else:
            loss_bboxes = self.loss_bboxes(
                pred_reg=reg_preds[foreground_idxs],
                tgt_box=tgt_boxes_pos,
                anchors=anchor_boxes.view(-1, 4)[foreground_idxs],
                num_boxes=num_foreground,
                use_giou=use_giou)

        loss_dict = dict(
            loss_cls=loss_labels,
            loss_reg=loss_bboxes,
        )

        return loss_dict
-
- # build criterion
def build_criterion(cfg, num_classes=80):
    """Create the training ``Criterion`` for the given config and class count."""
    return Criterion(cfg=cfg, num_classes=num_classes)
if __name__ == "__main__":
    # Running this module directly does nothing; it only provides definitions.
    pass
|