"""
reference:
https://github.com/facebookresearch/detr/blob/main/models/detr.py

by lyuwenyu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from .loss_utils import is_dist_avail_and_initialized, get_world_size
from .matcher import HungarianMatcher


# --------------- Criterion for RT-DETR ---------------
class SetCriterion(nn.Module):
    """Set-based training criterion (classification + box regression).

    For every set of predictions the criterion:
      1. matches predictions to ground-truth targets with ``HungarianMatcher``
         (or, for denoising outputs, with the fixed assignment stored in
         ``dn_meta``),
      2. computes the losses listed in ``self.losses`` on the matched pairs,
      3. scales each loss by the corresponding entry of ``self.weight_dict``.

    ``cfg`` must expose ``num_classes``, the matcher costs ``cost_class``,
    ``cost_bbox``, ``cost_giou`` and the loss weights ``loss_cls``,
    ``loss_box``, ``loss_giou``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_classes = cfg.num_classes
        self.losses = ['labels', 'boxes']
        # Weighting of negatives / focusing exponent used by the varifocal loss.
        self.alpha = 0.75  # For VFL
        self.gamma = 2.0
        # NOTE: the matcher is built with its own alpha/gamma (0.25 / 2.0),
        # independent of the VFL hyper-parameters above.
        self.matcher = HungarianMatcher(
            cfg.cost_class, cfg.cost_bbox, cfg.cost_giou, alpha=0.25, gamma=2.0)
        self.weight_dict = {
            'loss_cls': cfg.loss_cls,
            'loss_box': cfg.loss_box,
            'loss_giou': cfg.loss_giou,
        }

    def _gather_matched_boxes(self, outputs, targets, indices):
        """Return ``(idx, src_boxes, target_boxes)`` for the matched pairs.

        ``idx`` is the ``(batch_idx, src_idx)`` pair from
        ``_get_src_permutation_idx``; ``src_boxes`` are the predictions selected
        by it and ``target_boxes`` the ground-truth boxes in the same order.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat(
            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        return idx, src_boxes, target_boxes

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Compute the varifocal (IoU-aware) classification loss.

        Args:
            outputs: dict with ``'pred_logits'`` and ``'pred_boxes'``.
                ``pred_logits`` is indexed as ``[batch, query, class]`` here
                (its last dimension must equal ``self.num_classes`` for the
                BCE shapes to agree).
            targets: one dict per image with ``'labels'`` and ``'boxes'``.
            indices: per-image ``(src_idx, tgt_idx)`` index tensors.
            num_boxes: normalisation constant.

        Returns:
            ``{'loss_cls': loss}``.
        """
        assert 'pred_boxes' in outputs
        idx, src_boxes, target_boxes = self._gather_matched_boxes(outputs, targets, indices)

        # IoU between each matched prediction and its own target: only the
        # diagonal of the pairwise matrix is needed. Detached so the target
        # score carries no gradient.
        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
        ious = torch.diag(ious).detach()

        # One-hot class label. Unmatched queries get the extra index
        # ``num_classes``, whose one-hot column is then dropped, leaving
        # all-zero rows for them.
        src_logits = outputs['pred_logits']
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        # IoU-aware class label: the matched class entry holds the IoU,
        # every other entry is zero.
        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = ious.to(target_score_o.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        # Compute VFL: negatives are weighted by alpha * p^gamma, positives by
        # their target score. The weight is built from detached predictions.
        pred_score = torch.sigmoid(src_logits).detach()
        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score

        loss = F.binary_cross_entropy_with_logits(
            src_logits, target_score, weight=weight, reduction='none')
        # Mean over dim 1 times the size of dim 1, summed, then normalised by
        # the number of boxes.
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes

        return {'loss_cls': loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.

        Returns:
            ``{'loss_box': l1, 'loss_giou': giou}``, both divided by
            ``num_boxes``.
        """
        assert 'pred_boxes' in outputs
        _, src_boxes, target_boxes = self._gather_matched_boxes(outputs, targets, indices)

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses['loss_box'] = loss_bbox.sum() / num_boxes

        # Only the diagonal (prediction i vs. its own target i) is relevant.
        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_cxcywh_to_xyxy(src_boxes),
            box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image source indices into ``(batch_idx, src_idx)``."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image target indices into ``(batch_idx, tgt_idx)``."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss function registered under the name ``loss``.

        Raises:
            AssertionError: if ``loss`` is not ``'boxes'`` or ``'labels'``.
        """
        loss_map = {
            'boxes': self.loss_boxes,
            'labels': self.loss_labels,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def _weighted_losses(self, outputs, targets, indices, num_boxes, suffix=''):
        """Compute every loss in ``self.losses``, weight it, and suffix its key.

        Loss terms without an entry in ``self.weight_dict`` are dropped.
        """
        weighted = {}
        for loss in self.losses:
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            for k, v in l_dict.items():
                if k in self.weight_dict:
                    weighted[k + suffix] = v * self.weight_dict[k]
        return weighted

    def forward(self, outputs, targets):
        """Compute all weighted losses for the final, auxiliary and denoising outputs.

        Args:
            outputs: dict with ``'pred_logits'`` and ``'pred_boxes'``; may also
                hold ``'aux_outputs'`` (list of such dicts, one per
                intermediate layer) and ``'dn_aux_outputs'`` together with
                ``'dn_meta'`` for the denoising branch.
            targets: one dict per image with ``'labels'`` and ``'boxes'``.

        Returns:
            dict mapping loss names to weighted loss tensors. Final-layer keys
            are ``loss_cls`` / ``loss_box`` / ``loss_giou``; auxiliary layers
            append ``_aux_{i}`` and denoising outputs append ``_dn_{i}``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes.
        # The device is read from 'pred_logits' rather than from "whatever value
        # comes first in the dict", which breaks when that value is not a tensor.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=outputs['pred_logits'].device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = self._weighted_losses(outputs, targets, indices, num_boxes)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                aux_indices = self.matcher(aux_outputs, targets)
                losses.update(self._weighted_losses(
                    aux_outputs, targets, aux_indices, num_boxes, suffix=f'_aux_{i}'))

        # In case of cdn auxiliary losses. For rtdetr
        if 'dn_aux_outputs' in outputs:
            assert 'dn_meta' in outputs, \
                "outputs contains 'dn_aux_outputs' but the required 'dn_meta' entry is missing"
            dn_meta = outputs['dn_meta']
            # Denoising queries use the fixed assignment from dn_meta, not the matcher.
            dn_indices = self.get_cdn_matched_indices(dn_meta, targets)
            # Each ground-truth box appears once per denoising group.
            dn_num_boxes = num_boxes * dn_meta['dn_num_group']

            for i, dn_outputs in enumerate(outputs['dn_aux_outputs']):
                losses.update(self._weighted_losses(
                    dn_outputs, targets, dn_indices, dn_num_boxes, suffix=f'_dn_{i}'))

        return losses

    @staticmethod
    def get_cdn_matched_indices(dn_meta, targets):
        """Build the fixed (query, target) assignment for denoising outputs.

        Args:
            dn_meta: dict with ``'dn_positive_idx'`` (per-image index tensors
                of the positive denoising queries) and ``'dn_num_group'``.
            targets: one dict per image with ``'labels'``.

        Returns:
            A list with one ``(src_idx, tgt_idx)`` pair per image. For an image
            with ``n > 0`` targets, ``tgt_idx`` is ``arange(n)`` tiled
            ``dn_num_group`` times; for an image without targets both tensors
            are empty. An empty ``targets`` list yields ``[]``.

        Raises:
            AssertionError: if the number of positive denoising queries of an
                image differs from ``n * dn_num_group``.
        """
        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
        if not targets:
            # Nothing to match; also avoids indexing targets[0] below.
            return []

        device = targets[0]['labels'].device

        dn_match_indices = []
        for i, target in enumerate(targets):
            num_gt = len(target['labels'])
            if num_gt > 0:
                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
                gt_idx = gt_idx.tile(dn_num_group)
                assert len(dn_positive_idx[i]) == len(gt_idx), \
                    'number of positive denoising queries does not match num_gt * dn_num_group'
                dn_match_indices.append((dn_positive_idx[i], gt_idx))
            else:
                dn_match_indices.append((
                    torch.zeros(0, dtype=torch.int64, device=device),
                    torch.zeros(0, dtype=torch.int64, device=device),
                ))

        return dn_match_indices