import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.box_ops import *
from utils.misc import sigmoid_focal_loss
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized

from .matcher import UniformMatcher


class SetCriterion(nn.Module):
    """
    Computes the classification (sigmoid focal) loss and the box regression (GIoU) loss.
    This code is adapted from
    https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py
    """
    def __init__(self, cfg):
        super().__init__()
        # ------------- Basic parameters -------------
        self.cfg = cfg
        self.num_classes = cfg.num_classes
        # ------------- Focal loss -------------
        self.alpha = cfg.focal_loss_alpha
        self.gamma = cfg.focal_loss_gamma
        # ------------- Loss weights -------------
        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
                            'loss_reg': cfg.loss_reg_weight}
        # ------------- Matcher -------------
        self.matcher_cfg = cfg.matcher_hpy
        self.matcher = UniformMatcher(self.matcher_cfg['topk_candidates'])

    def loss_labels(self, pred_cls, tgt_cls, num_boxes):
        """
        Classification loss (sigmoid focal loss).
            pred_cls: (Tensor) [N, C], predicted logits for the valid (non-ignored) positions.
            tgt_cls:  (Tensor) [N, C], one-hot classification targets.
        """
        loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)

        return loss_cls.sum() / num_boxes

    def loss_bboxes(self, pred_box, tgt_box, num_boxes):
        """
        Box regression loss (GIoU loss).
            pred_box: (Tensor) [N, 4], matched predicted boxes in xyxy format.
            tgt_box:  (Tensor) [N, 4], matched target boxes in xyxy format.
        """
        # Pairwise GIoU: [N, N]; the diagonal holds the matched pred/target pairs.
        pred_giou = generalized_box_iou(pred_box, tgt_box)
        # GIoU loss
        loss_reg = 1. - torch.diag(pred_giou)

        return loss_reg.sum() / num_boxes

    def forward(self, outputs, targets):
        """
        outputs['pred_cls']: (Tensor) [B, M, C]
        outputs['pred_box']: (Tensor) [B, M, 4]
        outputs['anchors']:  (Tensor) [M, 4]
        outputs['mask']:     (Tensor) [B x M,]
        targets: (List) [dict{'boxes': [...],
                              'labels': [...],
                              'orig_size': ...}, ...]
        """
        # -------------------- Pre-process --------------------
        pred_box = outputs['pred_box']
        pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes)
        anchor_boxes = outputs['anchors']
        masks = ~outputs['mask']
        device = pred_box.device
        B = len(targets)

        # -------------------- Label assignment --------------------
        indices = self.matcher(pred_box, anchor_boxes, targets)

        # [M, 4] -> [1, M, 4] -> [B, M, 4]
        anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes)
        anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)

        ious = []
        pos_ious = []
        for i in range(B):
            src_idx, tgt_idx = indices[i]
            # IoU between predicted boxes and target boxes
            iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone())
            if iou.numel() == 0:
                max_iou = iou.new_full((iou.size(0),), 0)
            else:
                max_iou = iou.max(dim=1)[0]
            # IoU between anchor boxes and target boxes
            a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone())
            if a_iou.numel() == 0:
                pos_iou = a_iou.new_full((0,), 0)
            else:
                pos_iou = a_iou[src_idx, tgt_idx]
            ious.append(max_iou)
            pos_ious.append(pos_iou)

        # Predictions whose best IoU with any target exceeds the ignore threshold
        # are treated as neither positive nor negative samples.
        ious = torch.cat(ious)
        ignore_idx = ious > self.matcher_cfg['ignore_thresh']
        # Matched anchors whose IoU with their assigned target is too low are ignored.
        pos_ious = torch.cat(pos_ious)
        pos_ignore_idx = pos_ious < self.matcher_cfg['iou_thresh']

        # Flatten the per-image indices into indices over the [B x M,] predictions.
        src_idx = torch.cat(
            [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in
             enumerate(indices)])

        # [B x M,] classification targets: background = num_classes, ignored = -1
        gt_cls = torch.full(pred_cls.shape[:1],
                            self.num_classes,
                            dtype=torch.int64,
                            device=device)
        gt_cls[ignore_idx] = -1
        tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
        tgt_cls_o[pos_ignore_idx] = -1
        gt_cls[src_idx] = tgt_cls_o.to(device)

        foreground_idxs = (gt_cls >= 0) & (gt_cls != self.num_classes)
        num_foreground = foreground_idxs.sum()

        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_foreground)
        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()

        # -------------------- Classification loss --------------------
        gt_cls_target = torch.zeros_like(pred_cls)
        gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1
        valid_idxs = (gt_cls >= 0) & masks
        loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_cls_target[valid_idxs], num_foreground)

        # -------------------- Regression loss --------------------
        tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device)
        tgt_boxes = tgt_boxes[~pos_ignore_idx]
        matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
        loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground)

        # -------------------- Total loss --------------------
        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
                     loss_bboxes * self.weight_dict["loss_reg"]

        loss_dict = dict(
                loss_cls = loss_labels,
                loss_reg = loss_bboxes,
                losses = total_loss,
        )

        return loss_dict


if __name__ == "__main__":
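    # -----------------------------------------------------------------
    # Minimal smoke-test sketch (not part of the original file): it only
    # exercises SetCriterion.forward() with a fabricated config and batch.
    # The cfg field names mirror the attributes read in __init__, while the
    # tensor shapes and output keys ([B, M, C] logits, [B, M, 4] xyxy boxes,
    # [M, 4] cxcywh anchors, [B x M,] padding mask) are assumptions about
    # what the detector head produces; it still needs the repo's utils and
    # UniformMatcher to be importable.
    # -----------------------------------------------------------------
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        num_classes=20,
        focal_loss_alpha=0.25,
        focal_loss_gamma=2.0,
        loss_cls_weight=1.0,
        loss_reg_weight=1.0,
        matcher_hpy={'topk_candidates': 4,
                     'ignore_thresh': 0.7,
                     'iou_thresh': 0.15},
    )
    criterion = SetCriterion(cfg)

    B, M = 2, 100
    # Valid xyxy predictions and cxcywh anchors on a 100x100 canvas.
    xy1 = torch.rand(B, M, 2) * 50.
    wh = torch.rand(B, M, 2) * 40. + 1.
    outputs = {
        'pred_cls': torch.randn(B, M, cfg.num_classes),
        'pred_box': torch.cat([xy1, xy1 + wh], dim=-1),
        'anchors': torch.cat([torch.rand(M, 2) * 100.,
                              torch.rand(M, 2) * 30. + 1.], dim=-1),
        'mask': torch.zeros(B * M, dtype=torch.bool),  # assumed: no padded positions
    }
    targets = [
        {'boxes': torch.tensor([[10., 10., 40., 40.],
                                [50., 50., 90., 90.]]),
         'labels': torch.tensor([1, 3], dtype=torch.int64)}
        for _ in range(B)
    ]

    loss_dict = criterion(outputs, targets)
    print({k: v.item() for k, v in loss_dict.items()})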