|
@@ -0,0 +1,129 @@
|
|
|
|
|
+"""
|
|
|
|
|
+reference:
|
|
|
|
|
+https://github.com/facebookresearch/detr/blob/main/models/detr.py
|
|
|
|
|
+
|
|
|
|
|
+by lyuwenyu
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import torch
|
|
|
|
|
+import torch.nn as nn
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
|
|
+
|
|
|
|
|
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
|
|
|
|
|
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
|
|
|
|
|
+from .matcher import HungarianMatcher
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------- Criterion for DETR ---------------
|
|
|
|
|
+class SetCriterion(nn.Module):
|
|
|
|
|
+ def __init__(self, cfg):
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.num_classes = cfg.num_classes
|
|
|
|
|
+ self.losses = ['labels', 'boxes']
|
|
|
|
|
+ # -------- Loss weights --------
|
|
|
|
|
+ self.weight_dict = {'loss_cls': cfg.loss_cls,
|
|
|
|
|
+ 'loss_box': cfg.loss_box,
|
|
|
|
|
+ 'loss_giou': cfg.loss_giou}
|
|
|
|
|
+ for i in range(cfg.num_dec_layers - 1):
|
|
|
|
|
+ self.weight_dict.update({k + f'_aux_{i}': v for k, v in self.weight_dict.items()})
|
|
|
|
|
+ # -------- Matcher --------
|
|
|
|
|
+ self.matcher = HungarianMatcher(cfg.cost_class, cfg.cost_bbox, cfg.cost_giou)
|
|
|
|
|
+
|
|
|
|
|
+ def loss_labels(self, outputs, targets, indices, num_boxes):
|
|
|
|
|
+ assert 'pred_logits' in outputs
|
|
|
|
|
+ src_logits = outputs['pred_logits']
|
|
|
|
|
+
|
|
|
|
|
+ idx = self._get_src_permutation_idx(indices)
|
|
|
|
|
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
|
|
|
|
|
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
|
|
|
|
|
+ dtype=torch.int64, device=src_logits.device)
|
|
|
|
|
+ target_classes[idx] = target_classes_o
|
|
|
|
|
+
|
|
|
|
|
+ loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
|
|
|
|
|
+
|
|
|
|
|
+ return {'loss_cls': loss_cls.sum() / num_boxes}
|
|
|
|
|
+
|
|
|
|
|
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
|
|
|
|
|
+ assert 'pred_boxes' in outputs
|
|
|
|
|
+ idx = self._get_src_permutation_idx(indices)
|
|
|
|
|
+ src_boxes = outputs['pred_boxes'][idx]
|
|
|
|
|
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
|
|
|
|
+
|
|
|
|
|
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
|
|
|
|
|
+ loss_giou = 1 - torch.diag(generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
|
|
|
|
|
+ box_cxcywh_to_xyxy(target_boxes)))
|
|
|
|
|
+
|
|
|
|
|
+ return {'loss_box': loss_bbox.sum() / num_boxes,
|
|
|
|
|
+ 'loss_giou': loss_giou.sum() / num_boxes}
|
|
|
|
|
+
|
|
|
|
|
+ def _get_src_permutation_idx(self, indices):
|
|
|
|
|
+ # permute predictions following indices
|
|
|
|
|
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
|
|
|
|
|
+ src_idx = torch.cat([src for (src, _) in indices])
|
|
|
|
|
+
|
|
|
|
|
+ return batch_idx, src_idx
|
|
|
|
|
+
|
|
|
|
|
+ def _get_tgt_permutation_idx(self, indices):
|
|
|
|
|
+ # permute targets following indices
|
|
|
|
|
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
|
|
|
|
|
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
|
|
|
|
|
+
|
|
|
|
|
+ return batch_idx, tgt_idx
|
|
|
|
|
+
|
|
|
|
|
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
|
|
|
|
|
+ loss_map = {
|
|
|
|
|
+ 'boxes': self.loss_boxes,
|
|
|
|
|
+ 'labels': self.loss_labels,
|
|
|
|
|
+ }
|
|
|
|
|
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, outputs, targets):
|
|
|
|
|
+ outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
|
|
|
|
|
+
|
|
|
|
|
+ # Retrieve the matching between the outputs of the last layer and the targets
|
|
|
|
|
+ indices = self.matcher(outputs_without_aux, targets)
|
|
|
|
|
+
|
|
|
|
|
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
|
|
|
|
|
+ num_boxes = sum(len(t["labels"]) for t in targets)
|
|
|
|
|
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
|
|
|
|
+ if is_dist_avail_and_initialized():
|
|
|
|
|
+ torch.distributed.all_reduce(num_boxes)
|
|
|
|
|
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
|
|
|
|
|
+
|
|
|
|
|
+ # Compute all the requested losses
|
|
|
|
|
+ losses = {}
|
|
|
|
|
+ for loss in self.losses:
|
|
|
|
|
+ l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
|
|
|
|
|
+ losses.update(l_dict)
|
|
|
|
|
+
|
|
|
|
|
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
|
|
|
|
+ if 'aux_outputs' in outputs:
|
|
|
|
|
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
|
|
|
|
|
+ indices = self.matcher(aux_outputs, targets)
|
|
|
|
|
+ for loss in self.losses:
|
|
|
|
|
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
|
|
|
|
|
+ l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
|
|
|
|
|
+ losses.update(l_dict)
|
|
|
|
|
+
|
|
|
|
|
+ return losses
|
|
|
|
|
+
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ def get_cdn_matched_indices(dn_meta, targets):
|
|
|
|
|
+ '''get_cdn_matched_indices
|
|
|
|
|
+ '''
|
|
|
|
|
+ dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
|
|
|
|
|
+ num_gts = [len(t['labels']) for t in targets]
|
|
|
|
|
+ device = targets[0]['labels'].device
|
|
|
|
|
+
|
|
|
|
|
+ dn_match_indices = []
|
|
|
|
|
+ for i, num_gt in enumerate(num_gts):
|
|
|
|
|
+ if num_gt > 0:
|
|
|
|
|
+ gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
|
|
|
|
|
+ gt_idx = gt_idx.tile(dn_num_group)
|
|
|
|
|
+ assert len(dn_positive_idx[i]) == len(gt_idx)
|
|
|
|
|
+ dn_match_indices.append((dn_positive_idx[i], gt_idx))
|
|
|
|
|
+ else:
|
|
|
|
|
+ dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
|
|
|
|
|
+ torch.zeros(0, dtype=torch.int64, device=device)))
|
|
|
|
|
+
|
|
|
|
|
+ return dn_match_indices
|