""" reference: https://github.com/facebookresearch/detr/blob/main/models/detr.py by lyuwenyu """ import torch import torch.nn as nn import torch.nn.functional as F from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized from .matcher import HungarianMatcher # --------------- Criterion for DETR --------------- class SetCriterion(nn.Module): def __init__(self, cfg): super().__init__() self.num_classes = cfg.num_classes self.losses = ['labels', 'boxes'] self.eos_coef = 0.1 # -------- Loss weights -------- self.weight_dict = {'loss_cls': cfg.loss_cls, 'loss_box': cfg.loss_box, 'loss_giou': cfg.loss_giou} for i in range(cfg.num_dec_layers - 1): self.weight_dict.update({k + f'_aux_{i}': v for k, v in self.weight_dict.items()}) empty_weight = torch.ones(self.num_classes + 1) empty_weight[-1] = self.eos_coef self.register_buffer('empty_weight', empty_weight) # -------- Matcher -------- matcher_hpy = cfg.matcher_hpy self.matcher = HungarianMatcher(matcher_hpy['cost_class'], matcher_hpy['cost_bbox'], matcher_hpy['cost_giou']) def loss_labels(self, outputs, targets, indices, num_boxes): assert 'pred_logits' in outputs src_logits = outputs['pred_logits'] target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) return {'loss_cls': loss_cls.sum() / num_boxes} def loss_boxes(self, outputs, targets, indices, num_boxes): assert 'pred_boxes' in outputs idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') loss_giou = 1 - torch.diag(generalized_box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))) return {'loss_box': loss_bbox.sum() / num_boxes, 'loss_giou': loss_giou.sum() / num_boxes} def _get_src_permutation_idx(self, indices): # permute predictions following indices batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def _get_tgt_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): loss_map = { 'boxes': self.loss_boxes, 'labels': self.loss_labels, } return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) def forward(self, outputs, targets): outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k} # Retrieve the matching between the outputs of the last layer and the targets indices = self.matcher(outputs_without_aux, targets) # Compute the average number of target boxes accross all nodes, for normalization purposes num_boxes = sum(len(t["labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_boxes) num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() # Compute all the requested losses loss_dict = {} for loss in self.losses: l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes) l_dict = {k: l_dict[k] * 
self.weight_dict[k] for k in l_dict if k in self.weight_dict} loss_dict.update(l_dict) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. if 'aux_outputs' in outputs: for i, aux_outputs in enumerate(outputs['aux_outputs']): indices = self.matcher(aux_outputs, targets) for loss in self.losses: l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes) l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()} loss_dict.update(l_dict) # Total loss loss_dict["losses"] = sum(loss_dict.values()) return loss_dict
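

# ----------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes
# a cfg object exposing exactly the fields read in __init__ above; the
# weight and cost values below are illustrative, not tuned settings. The
# dummy outputs/targets follow the DETR conventions: 'pred_logits' is
# [B, num_queries, num_classes + 1], 'pred_boxes' is [B, num_queries, 4]
# in normalized (cx, cy, w, h), and each target dict carries 'labels' and
# 'boxes'. Requires the repo's HungarianMatcher and utils to be importable.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        num_classes=80,
        loss_cls=1.0,
        loss_box=5.0,
        loss_giou=2.0,
        num_dec_layers=6,
        matcher_hpy={'cost_class': 1.0, 'cost_bbox': 5.0, 'cost_giou': 2.0},
    )
    criterion = SetCriterion(cfg)

    B, num_queries = 2, 100
    outputs = {
        'pred_logits': torch.randn(B, num_queries, cfg.num_classes + 1),
        'pred_boxes': torch.rand(B, num_queries, 4),   # normalized cxcywh
    }
    targets = [
        {'labels': torch.randint(0, cfg.num_classes, (3,)),
         'boxes': torch.rand(3, 4)}
        for _ in range(B)
    ]

    loss_dict = criterion(outputs, targets)
    for k, v in loss_dict.items():
        print(f'{k}: {v.item():.4f}')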