| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- """
- reference:
- https://github.com/facebookresearch/detr/blob/main/models/detr.py
- by lyuwenyu
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
- from .loss_utils import is_dist_avail_and_initialized, get_world_size
- from .matcher import HungarianMatcher
- # --------------- Criterion for RT-DETR ---------------
class SetCriterion(nn.Module):
    """Training criterion for RT-DETR.

    Predictions are matched to ground truth with a Hungarian matcher. Two loss
    groups are then computed: an IoU-aware varifocal classification loss
    ('labels' -> 'loss_cls') and L1 + GIoU box losses ('boxes' -> 'loss_box',
    'loss_giou').

    The same losses are also applied to:
      * every entry of ``outputs['aux_outputs']``, each re-matched with the
        matcher (keys suffixed ``_aux_{i}``);
      * every entry of ``outputs['dn_aux_outputs']``, matched by construction
        from ``outputs['dn_meta']`` (keys suffixed ``_dn_{i}``).

    Every returned loss is already multiplied by its weight from ``weight_dict``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config object. This constructor reads ``num_classes``,
                ``cost_class``, ``cost_bbox``, ``cost_giou`` (matcher costs) and
                ``loss_cls``, ``loss_box``, ``loss_giou`` (loss weights).
        """
        super().__init__()
        self.num_classes = cfg.num_classes
        self.losses = ['labels', 'boxes']
        # Varifocal-loss hyper-parameters, used in loss_labels.
        self.alpha = 0.75
        self.gamma = 2.0
        self.matcher = HungarianMatcher(cfg.cost_class, cfg.cost_bbox, cfg.cost_giou, alpha=0.25, gamma=2.0)
        self.weight_dict = {'loss_cls': cfg.loss_cls,
                            'loss_box': cfg.loss_box,
                            'loss_giou': cfg.loss_giou}

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Compute the IoU-aware varifocal classification loss.

        Each matched query gets a soft target at its ground-truth class. The
        target equals the (detached) IoU between the predicted box and the
        matched ground-truth box. Every other (query, class) entry has target 0.

        Args:
            outputs: dict with 'pred_logits' and 'pred_boxes'. The boxes are in
                (cx, cy, w, h) format, as converted below.
            targets: list of per-image dicts with 'labels' and 'boxes'.
            indices: list of per-image ``(query_idx, target_idx)`` pairs.
            num_boxes: normaliser for the summed loss.

        Returns:
            ``{'loss_cls': scalar tensor}`` (unweighted).
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)

        # IoU of each matched prediction with its own target: the diagonal of
        # the pairwise IoU matrix. Detached so the IoU acts as a fixed soft label.
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
        ious = torch.diag(ious).detach()

        # One-hot class targets. Unmatched queries get the extra index
        # ``num_classes``, whose column is sliced off, leaving an all-zero row.
        src_logits = outputs['pred_logits']
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        # IoU-aware soft label: the IoU placed at the ground-truth class of
        # matched queries, zero elsewhere.
        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = ious.to(target_score_o.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        # Negatives are down-weighted by alpha * p^gamma; positives are
        # weighted by their soft target.
        pred_score = torch.sigmoid(src_logits).detach()
        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score

        loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
        # mean over dim 1, then * shape[1]: sum over every element, divided by num_boxes.
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {'loss_cls': loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the L1 regression loss and the GIoU loss of matched boxes.

        Each target dict must contain 'boxes', a tensor of shape
        [nb_target_boxes, 4] in (center_x, center_y, w, h) format, normalised by
        the image size.

        Returns:
            ``{'loss_box': tensor, 'loss_giou': tensor}`` (unweighted). Each is
            summed over matched pairs and divided by ``num_boxes``.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses['loss_box'] = loss_bbox.sum() / num_boxes

        # Diagonal of the pairwise GIoU matrix: each prediction against its own target.
        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_cxcywh_to_xyxy(src_boxes),
            box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image query indices into ``(batch_idx, query_idx)`` for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image target indices into ``(batch_idx, target_idx)`` for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss function named by ``loss`` ('labels' or 'boxes')."""
        loss_map = {
            'boxes': self.loss_boxes,
            'labels': self.loss_labels,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def _weighted_losses(self, outputs, targets, indices, num_boxes, suffix=''):
        """Compute every loss in ``self.losses``, scale it by ``weight_dict`` and append ``suffix`` to each key.

        Keys that have no entry in ``weight_dict`` are dropped.
        """
        weighted = {}
        for loss in self.losses:
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            weighted.update({k + suffix: l_dict[k] * self.weight_dict[k]
                             for k in l_dict if k in self.weight_dict})
        return weighted

    def forward(self, outputs, targets):
        """Compute all weighted losses for one batch.

        Args:
            outputs: model output dict. It must contain 'pred_logits' and
                'pred_boxes'. It may contain 'aux_outputs' (list of dicts),
                'dn_aux_outputs' (list of dicts) and 'dn_meta' (required when
                'dn_aux_outputs' is present).
            targets: list of per-image dicts with 'labels' and 'boxes'.

        Returns:
            dict mapping loss names (with ``_aux_{i}`` / ``_dn_{i}`` suffixes
            for auxiliary outputs) to weighted loss tensors.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
        # Match the outputs of the last layer to the targets.
        indices = self.matcher(outputs_without_aux, targets)

        # Average number of target boxes across all nodes, used for normalisation.
        # The device comes from 'pred_logits' (always required by loss_labels).
        # This avoids depending on the insertion order of the outputs dict.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=outputs['pred_logits'].device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        losses = self._weighted_losses(outputs, targets, indices, num_boxes)

        # Intermediate decoder layers: each is re-matched independently.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                aux_indices = self.matcher(aux_outputs, targets)
                losses.update(self._weighted_losses(
                    aux_outputs, targets, aux_indices, num_boxes, suffix=f'_aux_{i}'))

        # Contrastive-denoising outputs (RT-DETR). Matching comes from dn_meta,
        # not from the matcher, so the same indices serve every dn layer.
        if 'dn_aux_outputs' in outputs:
            assert 'dn_meta' in outputs, "'dn_aux_outputs' present but 'dn_meta' missing from outputs"
            dn_meta = outputs['dn_meta']
            dn_indices = self.get_cdn_matched_indices(dn_meta, targets)
            # Each GT box is repeated once per denoising group, so normalise by the repeated count.
            dn_num_boxes = num_boxes * dn_meta['dn_num_group']
            for i, dn_outputs in enumerate(outputs['dn_aux_outputs']):
                losses.update(self._weighted_losses(
                    dn_outputs, targets, dn_indices, dn_num_boxes, suffix=f'_dn_{i}'))

        return losses

    @staticmethod
    def get_cdn_matched_indices(dn_meta, targets):
        """Build the fixed ``(query_idx, target_idx)`` matches for denoising queries.

        Args:
            dn_meta: dict with 'dn_positive_idx' (per-image index tensors) and
                'dn_num_group' (number of denoising groups). For an image with
                ``num_gt`` boxes, ``dn_positive_idx[i]`` must have length
                ``num_gt * dn_num_group``; presumably these are the positive
                denoising query slots (confirm against the query generator).
            targets: list of per-image dicts with 'labels'.

        Returns:
            list with one ``(query_idx, target_idx)`` pair per image. Images
            without ground truth get two empty int64 tensors. Returns ``[]``
            for an empty ``targets`` list.
        """
        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]

        dn_match_indices = []
        for i, t in enumerate(targets):
            labels = t['labels']
            num_gt = len(labels)
            device = labels.device
            if num_gt > 0:
                # GT indices 0..num_gt-1 repeated once per group: [0, 1, .., 0, 1, ..].
                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device).tile(dn_num_group)
                assert len(dn_positive_idx[i]) == len(gt_idx), (
                    f"image {i}: expected {len(gt_idx)} positive dn queries, got {len(dn_positive_idx[i])}")
                dn_match_indices.append((dn_positive_idx[i], gt_idx))
            else:
                dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device),
                                         torch.zeros(0, dtype=torch.int64, device=device)))

        return dn_match_indices
|