
modify RT-DETR

yjh0410 committed 1 year ago
parent commit 4cf3b2acbc

+ 7 - 4
engine.py

@@ -1145,8 +1145,8 @@ class RTDetrTrainer(object):
         os.makedirs(self.path_to_save, exist_ok=True)
 
         # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
-        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 0.05, 'lr0': 0.0002, 'backbone_lr_ratio': 0.1}
-        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1, 'warmup_iters': 1000} # no lr decay
+        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 0.0001, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
+        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 1.0, 'warmup_iters': 2000} # no lr decay
         self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
 
         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
@@ -1299,6 +1299,9 @@ class RTDetrTrainer(object):
                                 
             # To device
             images = images.to(self.device, non_blocking=True).float()
+            for tgt in targets:
+                tgt['boxes'] = tgt['boxes'].to(self.device)
+                tgt['labels'] = tgt['labels'].to(self.device)
 
             # Multi scale
             if self.args.multi_scale:
@@ -1321,7 +1324,7 @@ class RTDetrTrainer(object):
             with torch.cuda.amp.autocast(enabled=self.args.fp16):
                 outputs = model(images, targets)
                 # Compute loss
-                loss_dict = self.criterion(*outputs, targets)
+                loss_dict = self.criterion(outputs, targets)
                 losses = sum(loss_dict.values())
                 # Grad Accumulate
                 if self.grad_accumulate > 1:
@@ -1349,7 +1352,7 @@ class RTDetrTrainer(object):
                     self.model_ema.update(model)
 
             # Update log
-            metric_logger.update(**loss_dict_reduced)
+            metric_logger.update(loss=losses.item(), **loss_dict_reduced)
             metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
             metric_logger.update(grad_norm=grad_norm)
             metric_logger.update(size=img_size)

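Note: this commit changes the model/criterion contract in the train loop. Targets are moved to the device up front, and the criterion now consumes the decoder's output dict instead of an unpacked tuple. A minimal sketch of the new call pattern (variable names illustrative; shapes follow the files below):

    # Sketch of the new contract, not verbatim trainer code.
    images = images.to(device, non_blocking=True).float()
    for tgt in targets:
        tgt['boxes'] = tgt['boxes'].to(device)    # [num_gt, 4], normalized cxcywh
        tgt['labels'] = tgt['labels'].to(device)  # [num_gt]
    outputs = model(images, targets)              # dict with 'pred_logits', 'pred_boxes', aux keys
    loss_dict = criterion(outputs, targets)       # e.g. keys 'loss_cls', 'loss_box', 'loss_giou', ...
    losses = sum(loss_dict.values())
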
+ 39 - 66
models/detectors/rtdetr/basic_modules/dn_compoments.py

@@ -5,13 +5,13 @@ def inverse_sigmoid(x, eps=1e-5):
     x = x.clamp(min=0., max=1.)
     return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
 
-def bbox_cxcywh_to_xyxy(x):
+def box_cxcywh_to_xyxy(x):
     x_c, y_c, w, h = x.unbind(-1)
     b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
          (x_c + 0.5 * w), (y_c + 0.5 * h)]
     return torch.stack(b, dim=-1)
 
-def bbox_xyxy_to_cxcywh(x):
+def box_xyxy_to_cxcywh(x):
     x0, y0, x1, y1 = x.unbind(-1)
     b = [(x0 + x1) / 2, (y0 + y1) / 2,
          (x1 - x0), (y1 - y0)]
@@ -23,109 +23,83 @@ def get_contrastive_denoising_training_group(targets,
                                              class_embed,
                                              num_denoising=100,
                                              label_noise_ratio=0.5,
-                                             box_noise_scale=1.0):
+                                             box_noise_scale=1.0,):
     if num_denoising <= 0:
         return None, None, None, None
-    num_gts = [len(t["labels"]) for t in targets]
+
+    num_gts = [len(t['labels']) for t in targets]
+    device = targets[0]['labels'].device
+    
     max_gt_num = max(num_gts)
     if max_gt_num == 0:
         return None, None, None, None
 
     num_group = num_denoising // max_gt_num
     num_group = 1 if num_group == 0 else num_group
-
     # pad gt to max_num of a batch
-    bs = len(targets)
-    # [bs, max_gt_num]
-    input_query_class = torch.full([bs, max_gt_num], num_classes, device=class_embed.device).long()
-    # [bs, max_gt_num, 4]
-    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=class_embed.device)
-    # [bs, max_gt_num]
-    pad_gt_mask = torch.zeros([bs, max_gt_num], device=class_embed.device)
+    bs = len(num_gts)
+
+    input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
+    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
+    pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
+
     for i in range(bs):
         num_gt = num_gts[i]
         if num_gt > 0:
-            input_query_class[i, :num_gt] = targets[i]["labels"]
-            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
+            input_query_class[i, :num_gt] = targets[i]['labels']
+            input_query_bbox[i, :num_gt] = targets[i]['boxes']
             pad_gt_mask[i, :num_gt] = 1
-
     # each group has positive and negative queries.
-    input_query_class = input_query_class.repeat(1, 2 * num_group)  # [bs, 2*num_denoising], num_denoising = 2 * num_group * max_gt_num
-    input_query_bbox = input_query_bbox.repeat(1, 2 * num_group, 1) # [bs, 2*num_denoising, 4]
-    pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)              # [bs, 2*num_denoising]
-
+    input_query_class = input_query_class.tile([1, 2 * num_group])
+    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
     # positive and negative mask
-    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=class_embed.device)
+    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device)
     negative_gt_mask[:, max_gt_num:] = 1
-    negative_gt_mask = negative_gt_mask.repeat(1, num_group, 1)
+    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
     positive_gt_mask = 1 - negative_gt_mask
-
     # contrastive denoising training positive index
     positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
     dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
     dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
-    
     # total denoising queries
-    num_denoising = int(max_gt_num * 2 * num_group)  # num_denoising *= 2
+    num_denoising = int(max_gt_num * 2 * num_group)
 
     if label_noise_ratio > 0:
-        input_query_class = input_query_class.flatten()  # [bs * num_denoising]
-        pad_gt_mask = pad_gt_mask.flatten()
-        # half of bbox prob
-        mask = torch.rand(input_query_class.shape, device=class_embed.device) < (label_noise_ratio * 0.5)
-        chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
+        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
         # randomly put a new one here
-        new_label = torch.randint_like(
-            chosen_idx, 0, num_classes, dtype=input_query_class.dtype, device=class_embed.device) # [b * num_denoising]
-        # [bs * num_denoising]
-        input_query_class = torch.scatter(input_query_class, 0, chosen_idx, new_label)
-        # input_query_class.scatter_(chosen_idx, new_label)
-        # [bs * num_denoising] -> # [bs, num_denoising]
-        input_query_class = input_query_class.reshape(bs, num_denoising)
-        pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising)
+        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
 
     if box_noise_scale > 0:
-        known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)
-        diff = torch.tile(input_query_bbox[..., 2:] * 0.5,
-                           [1, 1, 2]) * box_noise_scale
-
+        known_bbox = box_cxcywh_to_xyxy(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
         rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
-        rand_part = torch.rand(input_query_bbox.shape, device=class_embed.device)
-        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
-            1 - negative_gt_mask)
+        rand_part = torch.rand_like(input_query_bbox)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
         rand_part *= rand_sign
         known_bbox += rand_part * diff
-        known_bbox.clamp_(min=0.0, max=1.0)
-        input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = box_xyxy_to_cxcywh(known_bbox)
         input_query_bbox = inverse_sigmoid(input_query_bbox)
+    input_query_class = class_embed(input_query_class)
 
-    # [num_classes + 1, hidden_dim]
-    class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
-    # input_query_class = paddle.gather(class_embed, input_query_class.flatten(), axis=0)
-
-    # input_query_class: [bs, num_denoising] -> [bs*num_denoising, hidden_dim]
-    input_query_class = torch.index_select(class_embed, 0, input_query_class.flatten())
-    # [bs*num_denoising, hidden_dim] -> [bs, num_denoising, hidden_dim]
-    input_query_class = input_query_class.reshape(bs, num_denoising, -1)
-    
     tgt_size = num_denoising + num_queries
-    attn_mask = torch.ones([tgt_size, tgt_size], device=class_embed.device) < 0
+    # attn_mask = torch.ones([tgt_size, tgt_size], device=device) < 0
+    attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
     # match query cannot see the reconstruction
     attn_mask[num_denoising:, :num_denoising] = True
+    
     # reconstruct cannot see each other
     for i in range(num_group):
         if i == 0:
-            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
-                      2 * (i + 1):num_denoising] = True
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
         if i == num_group - 1:
-            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
-                      i * 2] = True
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True
         else:
-            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
-                      2 * (i + 1):num_denoising] = True
-            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
-                      2 * i] = True
-    attn_mask = ~attn_mask
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True
+        
     dn_meta = {
         "dn_positive_idx": dn_positive_idx,
         "dn_num_group": num_group,
@@ -133,4 +107,3 @@ def get_contrastive_denoising_training_group(targets,
     }
 
     return input_query_class, input_query_bbox, attn_mask, dn_meta
-

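Note: for intuition about the refactored group builder, a small usage sketch. The embedding table mirrors how rtdetr_decoder.py now builds denoising_class_embed; the numbers and import path are illustrative:

    import torch
    import torch.nn as nn
    # hypothetical package layout, following this commit's file paths
    from models.detectors.rtdetr.basic_modules.dn_compoments import (
        get_contrastive_denoising_training_group)

    # One image with 3 GTs and num_denoising=100: num_group = 100 // 3 = 33,
    # so the builder emits 2 * 33 * 3 = 198 denoising queries
    # (one positive and one negative copy per padded GT slot).
    targets = [{'labels': torch.tensor([1, 3, 7]), 'boxes': torch.rand(3, 4)}]
    class_embed = nn.Embedding(80 + 1, 256, padding_idx=80)  # num_classes + 1 rows
    dn_class, dn_bbox, attn_mask, dn_meta = get_contrastive_denoising_training_group(
        targets, num_classes=80, num_queries=300, class_embed=class_embed)
    # dn_class: [1, 198, 256]; dn_bbox: [1, 198, 4] (inverse-sigmoid space)
    # attn_mask: [498, 498] bool; True now means "blocked", which is why the
    # trailing `attn_mask = ~attn_mask` inversion was removed.
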
+ 20 - 81
models/detectors/rtdetr/basic_modules/transformer.py

@@ -4,12 +4,11 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_, uniform_
 
 try:
-    from .basic import get_activation
+    from .basic import FFN
 except:
-    from  basic import get_activation
+    from  basic import FFN
 
 
 def get_clones(module, N):
@@ -23,38 +22,6 @@ def inverse_sigmoid(x, eps=1e-5):
     return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
 
 
-# ----------------- MLP modules -----------------
-class MLP(nn.Module):
-    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        h = [hidden_dim] * (num_layers - 1)
-        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
-
-    def forward(self, x):
-        for i, layer in enumerate(self.layers):
-            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
-        return x
-
-class FFN(nn.Module):
-    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
-        super().__init__()
-        self.fpn_dim = round(d_model * mlp_ratio)
-        self.linear1 = nn.Linear(d_model, self.fpn_dim)
-        self.activation = get_activation(act_type)
-        self.dropout2 = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(self.fpn_dim, d_model)
-        self.dropout3 = nn.Dropout(dropout)
-        self.norm = nn.LayerNorm(d_model)
-
-    def forward(self, src):
-        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
-        src = src + self.dropout3(src2)
-        src = self.norm(src)
-        
-        return src
-    
-
 # ----------------- Basic Transformer Ops -----------------
 def multi_scale_deformable_attn_pytorch(
     value: torch.Tensor,
@@ -137,7 +104,7 @@ class MSDeformableAttention(nn.Module):
         """
         Default initialization for Parameters of Module.
         """
-        constant_(self.sampling_offsets.weight.data, 0.0)
+        nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
         thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
             2.0 * math.pi / self.num_heads
         )
@@ -151,12 +118,16 @@ class MSDeformableAttention(nn.Module):
             grid_init[:, :, i, :] *= i + 1
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
-        constant_(self.attention_weights.weight.data, 0.0)
-        constant_(self.attention_weights.bias.data, 0.0)
-        xavier_uniform_(self.value_proj.weight.data)
-        constant_(self.value_proj.bias.data, 0.0)
-        xavier_uniform_(self.output_proj.weight.data)
-        constant_(self.output_proj.bias.data, 0.0)
+
+        # attention weight
+        nn.init.constant_(self.attention_weights.weight, 0.0)
+        nn.init.constant_(self.attention_weights.bias, 0.0)
+
+        # proj
+        nn.init.xavier_uniform_(self.value_proj.weight)
+        nn.init.constant_(self.value_proj.bias, 0.0)
+        nn.init.xavier_uniform_(self.output_proj.weight)
+        nn.init.constant_(self.output_proj.bias, 0.0)
 
     def forward(self,
                 query,
@@ -195,9 +166,8 @@ class MSDeformableAttention(nn.Module):
         # [bs, all_hw, num_head, num_level*num_sample_point]
         attention_weights = self.attention_weights(query).reshape(
             [bs, num_query, self.num_heads, self.num_levels * self.num_points])
-        attention_weights = attention_weights.softmax(-1)
         # [bs, all_hw, num_head, num_level, num_sample_point]
-        attention_weights = attention_weights.reshape(
+        attention_weights = attention_weights.softmax(-1).reshape(
             [bs, num_query, self.num_heads, self.num_levels, self.num_points])
 
         # [bs, num_query, num_heads, num_levels, num_points, 2]
@@ -210,7 +180,7 @@ class MSDeformableAttention(nn.Module):
                 [1, 1, 1, self.num_levels, 1, 2])
             sampling_locations = (
                 reference_points[:, :, None, :, None, :]
-                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+                + sampling_offsets / offset_normalizer
             )
         elif reference_points.shape[-1] == 4:
             sampling_locations = (
@@ -260,20 +230,9 @@ class TransformerEncoderLayer(nn.Module):
         # Feedforward Network
         self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
 
-        self._reset_parameters()
-
     def with_pos_embed(self, tensor, pos):
         return tensor if pos is None else tensor + pos
 
-    def _reset_parameters(self):
-        def linear_init_(module):
-            bound = 1 / math.sqrt(module.weight.shape[0])
-            uniform_(module.weight, -bound, bound)
-            if hasattr(module, "bias") and module.bias is not None:
-                uniform_(module.bias, -bound, bound)
-        linear_init_(self.ffn.linear1)
-        linear_init_(self.ffn.linear2)
-
     def forward(self, src, pos_embed):
         """
         Input:
@@ -395,7 +354,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         self.act_type = act_type
         # ---------------- Network parameters ----------------
         ## Multi-head Self-Attn
-        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
         self.dropout1 = nn.Dropout(dropout)
         self.norm1 = nn.LayerNorm(d_model)
         ## CrossAttention
@@ -405,22 +364,9 @@ class DeformableTransformerDecoderLayer(nn.Module):
         ## FFN
         self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
 
-        self._reset_parameters()
-
     def with_pos_embed(self, tensor, pos):
         return tensor if pos is None else tensor + pos
 
-    def _reset_parameters(self):
-        def linear_init_(module):
-            bound = 1 / math.sqrt(module.weight.shape[0])
-            uniform_(module.weight, -bound, bound)
-            if hasattr(module, "bias") and module.bias is not None:
-                uniform_(module.bias, -bound, bound)
-        linear_init_(self.ffn.linear1)
-        linear_init_(self.ffn.linear2)
-        xavier_uniform_(self.ffn.linear1.weight)
-        xavier_uniform_(self.ffn.linear2.weight)
-
     def forward(self,
                 tgt,
                 reference_points,
@@ -431,12 +377,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
                 query_pos_embed=None):
         # ---------------- MSHA for Object Query -----------------
         q = k = self.with_pos_embed(tgt, query_pos_embed)
-        if attn_mask is not None:
-            attn_mask = torch.where(
-                attn_mask.bool(),
-                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=attn_mask.device),
-                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=attn_mask.device))
-        tgt2 = self.self_attn(q, k, value=tgt)[0]
+        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
 
@@ -504,19 +445,17 @@ class DeformableTransformerDecoder(nn.Module):
                            memory_spatial_shapes, attn_mask,
                            memory_mask, query_pos_embed)
 
-            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
-                ref_points_detach))
+            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
 
             dec_out_logits.append(score_head[i](output))
             if i == 0:
                 dec_out_bboxes.append(inter_ref_bbox)
             else:
                 dec_out_bboxes.append(
-                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
-                        ref_points)))
+                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
 
             ref_points = inter_ref_bbox
-            ref_points_detach = inter_ref_bbox.detach()
+            ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox
 
         return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
 

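Note: the decoder layer now passes the boolean mask straight into nn.MultiheadAttention (built with batch_first=True) instead of converting it to an additive -inf mask, which is why dn_compoments.py stopped inverting attn_mask. A standalone sketch of PyTorch's convention (True = blocked):

    import torch
    import torch.nn as nn

    attn = nn.MultiheadAttention(256, 8, batch_first=True)
    x = torch.rand(2, 10, 256)                    # [batch, seq, dim]
    mask = torch.zeros(10, 10, dtype=torch.bool)  # False = may attend
    mask[5:, :5] = True                           # queries 5..9 cannot see keys 0..4
    out, _ = attn(x, x, x, attn_mask=mask)        # True entries are blocked
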
+ 170 - 393
models/detectors/rtdetr/loss.py

@@ -1,424 +1,201 @@
-import math
+"""
+reference: 
+https://github.com/facebookresearch/detr/blob/main/models/detr.py
+
+by lyuwenyu
+"""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 try:
-    from .loss_utils import varifocal_loss_with_logits, sigmoid_focal_loss
-    from .loss_utils import box_cxcywh_to_xyxy, bbox_iou
+    from .loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
     from .loss_utils import is_dist_avail_and_initialized, get_world_size
-    from .loss_utils import GIoULoss
     from .matcher import HungarianMatcher
 except:
-    from loss_utils import varifocal_loss_with_logits, sigmoid_focal_loss
-    from loss_utils import box_cxcywh_to_xyxy, bbox_iou
+    from loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
     from loss_utils import is_dist_avail_and_initialized, get_world_size
-    from loss_utils import GIoULoss
     from matcher import HungarianMatcher
 
 
 # --------------- Criterion for RT-DETR ---------------
 def build_criterion(cfg, num_classes=80):
-    return Criterion(cfg, num_classes)
-
-class Criterion(object):
-    def __init__(self, cfg, num_classes=80):
-        self.matcher = HungarianMatcher(cfg['matcher_hpy']['cost_class'],
-                                        cfg['matcher_hpy']['cost_bbox'],
-                                        cfg['matcher_hpy']['cost_giou'],
-                                        alpha=0.25,
-                                        gamma=2.0)
-        self.loss = DINOLoss(num_classes   = num_classes,
-                                matcher    = self.matcher,
-                                aux_loss   = True,
-                                use_vfl    = cfg['use_vfl'],
-                                loss_coeff = cfg['loss_coeff'])
-
-    def __call__(self, dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta, targets=None):
-        assert targets is not None
-
-        gt_labels = [t['labels'].to(dec_out_bboxes.device) for t in targets]  # (List[torch.Tensor]) -> List[[N,]]
-        gt_boxes  = [t['boxes'].to(dec_out_bboxes.device)  for t in targets]  # (List[torch.Tensor]) -> List[[N, 4]]
-
-        if dn_meta is not None:
-            if isinstance(dn_meta, list):
-                dual_groups = len(dn_meta) - 1
-                dec_out_bboxes = torch.chunk(
-                    dec_out_bboxes, dual_groups + 1, dim=2)
-                dec_out_logits = torch.chunk(
-                    dec_out_logits, dual_groups + 1, dim=2)
-                enc_topk_bboxes = torch.chunk(
-                    enc_topk_bboxes, dual_groups + 1, dim=1)
-                enc_topk_logits = torch.chunk(
-                    enc_topk_logits, dual_groups + 1, dim=1)
-
-                loss = {}
-                for g_id in range(dual_groups + 1):
-                    if dn_meta[g_id] is not None:
-                        dn_out_bboxes_gid, dec_out_bboxes_gid = torch.split(
-                            dec_out_bboxes[g_id],
-                            dn_meta[g_id]['dn_num_split'],
-                            dim=2)
-                        dn_out_logits_gid, dec_out_logits_gid = torch.split(
-                            dec_out_logits[g_id],
-                            dn_meta[g_id]['dn_num_split'],
-                            dim=2)
-                    else:
-                        dn_out_bboxes_gid, dn_out_logits_gid = None, None
-                        dec_out_bboxes_gid = dec_out_bboxes[g_id]
-                        dec_out_logits_gid = dec_out_logits[g_id]
-                    out_bboxes_gid = torch.cat([
-                        enc_topk_bboxes[g_id].unsqueeze(0),
-                        dec_out_bboxes_gid
-                    ])
-                    out_logits_gid = torch.cat([
-                        enc_topk_logits[g_id].unsqueeze(0),
-                        dec_out_logits_gid
-                    ])
-                    loss_gid = self.loss(
-                        out_bboxes_gid,
-                        out_logits_gid,
-                        gt_boxes,
-                        gt_labels,
-                        dn_out_bboxes=dn_out_bboxes_gid,
-                        dn_out_logits=dn_out_logits_gid,
-                        dn_meta=dn_meta[g_id])
-                    # sum loss
-                    for key, value in loss_gid.items():
-                        loss.update({
-                            key: loss.get(key, torch.zeros([1], device=out_bboxes_gid.device)) + value
-                        })
-
-                # average across (dual_groups + 1)
-                for key, value in loss.items():
-                    loss.update({key: value / (dual_groups + 1)})
-                return loss
-            else:
-                dn_out_bboxes, dec_out_bboxes = torch.split(
-                    dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
-                dn_out_logits, dec_out_logits = torch.split(
-                    dec_out_logits, dn_meta['dn_num_split'], dim=2)
-        else:
-            dn_out_bboxes, dn_out_logits = None, None
-
-        out_bboxes = torch.cat(
-            [enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
-        out_logits = torch.cat(
-            [enc_topk_logits.unsqueeze(0), dec_out_logits])
-
-        return self.loss(out_bboxes,
-                         out_logits,
-                         gt_boxes,
-                         gt_labels,
-                         dn_out_bboxes=dn_out_bboxes,
-                         dn_out_logits=dn_out_logits,
-                         dn_meta=dn_meta)
-
-
-# --------------- DETR series loss ---------------
-class DETRLoss(nn.Module):
-    """Modified Paddle DETRLoss class without mask loss."""
-    def __init__(self,
-                 num_classes=80,
-                 matcher='HungarianMatcher',
-                 aux_loss=True,
-                 use_vfl=False,
-                 loss_coeff={'class': 1,
-                             'bbox':  5,
-                             'giou':  2,},
-                 ):
-        super(DETRLoss, self).__init__()
+    matcher = HungarianMatcher(cfg['matcher_hpy'], alpha=0.25, gamma=2.0)
+    weight_dict = {'loss_cls':  cfg['loss_coeff']['class'],
+                   'loss_box':  cfg['loss_coeff']['bbox'],
+                   'loss_giou': cfg['loss_coeff']['giou']}
+    criterion = Criterion(matcher, weight_dict, num_classes=num_classes)
+
+    return criterion
+
+
+class Criterion(nn.Module):
+    """ This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+    def __init__(self, matcher, weight_dict, num_classes=80):
+        """ Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+        """
+        super().__init__()
         self.num_classes = num_classes
         self.matcher = matcher
-        self.loss_coeff = loss_coeff
-        self.aux_loss = aux_loss
-        self.use_vfl = use_vfl
-        self.giou_loss = GIoULoss(reduction='none')
-
-    def _get_loss_class(self,
-                        logits,
-                        gt_class,
-                        match_indices,
-                        bg_index,
-                        num_gts,
-                        postfix="",
-                        iou_score=None):
-        # logits: [b, query, num_classes], gt_class: list[[n, 1]]
-        name_class = "loss_class" + postfix
-
-        target_label = torch.full(logits.shape[:2], bg_index, device=logits.device).long()
-        bs, num_query_objects = target_label.shape
-        num_gt = sum(len(a) for a in gt_class)
-        if num_gt > 0:
-            index, updates = self._get_index_updates(
-                num_query_objects, gt_class, match_indices)
-            target_label = target_label.reshape(-1, 1)
-            target_label[index] = updates.long()[:, None]
-            # target_label = paddle.scatter(target_label, index, updates.long())
-            target_label = target_label.reshape(bs, num_query_objects)
-
-        # one-hot label
-        target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1].float()
-        if iou_score is not None and self.use_vfl:
-            target_score = torch.zeros([bs, num_query_objects], device=logits.device)
-            if num_gt > 0:
-                target_score = target_score.reshape(-1, 1)
-                target_score[index] = iou_score.float()
-                # target_score = paddle.scatter(target_score, index, iou_score)
-            target_score = target_score.reshape(bs, num_query_objects, 1) * target_label
-            loss_cls = varifocal_loss_with_logits(logits,
-                                                  target_score,
-                                                  target_label,
-                                                  num_gts / num_query_objects)
-        else:
-            loss_cls = sigmoid_focal_loss(logits,
-                                          target_label,
-                                          num_gts / num_query_objects)
-
-        return {name_class: loss_cls * self.loss_coeff['class']}
-
-    def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
-                       postfix=""):
-        # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
-        name_bbox = "loss_bbox" + postfix
-        name_giou = "loss_giou" + postfix
-
-        loss = dict()
-        if sum(len(a) for a in gt_bbox) == 0:
-            loss[name_bbox] = torch.as_tensor([0.], device=boxes.device)
-            loss[name_giou] = torch.as_tensor([0.], device=boxes.device)
-            return loss
-
-        # prepare positive samples
-        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox, match_indices)
-
-        # Compute L1 loss
-        loss[name_bbox] = F.l1_loss(src_bbox, target_bbox, reduction='none')
-        loss[name_bbox] = loss[name_bbox].sum() / num_gts
-        loss[name_bbox] = self.loss_coeff['bbox'] * loss[name_bbox]
+        self.weight_dict = weight_dict
+        self.losses = ['labels', 'boxes']
+
+        self.alpha = 0.75  # For VFL
+        self.gamma = 2.0
+
+    def loss_labels(self, outputs, targets, indices, num_boxes):
+        """Compute the varifocal classification loss (VFL)."""
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        # Compute IoU between pred and target
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
+        ious = torch.diag(ious).detach()
+
+        # One-hot class label
+        src_logits = outputs['pred_logits']
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
+
+        # Iou-aware class label
+        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
+        target_score_o[idx] = ious.to(target_score_o.dtype)
+        target_score = target_score_o.unsqueeze(-1) * target
+
+        # Compute VFL
+        pred_score = F.sigmoid(src_logits).detach()
+        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
         
-        # Compute GIoU loss
-        loss[name_giou] = self.giou_loss(box_cxcywh_to_xyxy(src_bbox),
-                                         box_cxcywh_to_xyxy(target_bbox))
-        loss[name_giou] = loss[name_giou].sum() / num_gts
-        loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
+        loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
+        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
 
-        return loss
+        return {'loss_cls': loss}
 
-    def _get_loss_aux(self,
-                      boxes,
-                      logits,
-                      gt_bbox,
-                      gt_class,
-                      bg_index,
-                      num_gts,
-                      dn_match_indices=None,
-                      postfix=""):
-        loss_class = []
-        loss_bbox, loss_giou = [], []
-        if dn_match_indices is not None:
-            match_indices = dn_match_indices
-        for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)):
-            if dn_match_indices is None:
-                match_indices = self.matcher(
-                    aux_boxes,
-                    aux_logits,
-                    gt_bbox,
-                    gt_class,
-                    )
-            if self.use_vfl:
-                if sum(len(a) for a in gt_bbox) > 0:
-                    src_bbox, target_bbox = self._get_src_target_assign(
-                        aux_boxes.detach(), gt_bbox, match_indices)
-                    iou_score = bbox_iou(box_cxcywh_to_xyxy(src_bbox),
-                                         box_cxcywh_to_xyxy(target_bbox))
-                else:
-                    iou_score = None
-            else:
-                iou_score = None
-            loss_class.append(
-                self._get_loss_class(aux_logits, gt_class, match_indices,
-                                     bg_index, num_gts, postfix, iou_score)[
-                                         'loss_class' + postfix])
-            loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
-                                        num_gts, postfix)
-            loss_bbox.append(loss_['loss_bbox' + postfix])
-            loss_giou.append(loss_['loss_giou' + postfix])
-
-        loss = {
-            "loss_class_aux" + postfix: sum(loss_class),
-            "loss_bbox_aux"  + postfix: sum(loss_bbox),
-            "loss_giou_aux"  + postfix: sum(loss_giou)
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        losses = {}
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+        losses['loss_box'] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(generalized_box_iou(
+                box_cxcywh_to_xyxy(src_boxes),
+                box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'boxes': self.loss_boxes,
+            'labels': self.loss_labels,
         }
-
-        return loss
-
-    def _get_index_updates(self, num_query_objects, target, match_indices):
-        batch_idx = torch.cat([
-            torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)
-        ])
-        src_idx = torch.cat([src for (src, _) in match_indices])
-        src_idx += (batch_idx * num_query_objects)
-        target_assign = torch.cat([
-            torch.gather(t, 0, dst.to(t.device)) for t, (_, dst) in zip(target, match_indices)
-        ])
-        return src_idx, target_assign
-
-    def _get_src_target_assign(self, src, target, match_indices):
-        src_assign = torch.cat([t[I] if len(I) > 0 else torch.zeros([0, t.shape[-1]], device=src.device)
-            for t, (I, _) in zip(src, match_indices)
-        ])
-
-        target_assign = torch.cat([t[J] if len(J) > 0 else torch.zeros([0, t.shape[-1]], device=src.device)
-            for t, (_, J) in zip(target, match_indices)
-        ])
-
-        return src_assign, target_assign
-
-    def _get_num_gts(self, targets):
-        num_gts = sum(len(a) for a in targets)
-        num_gts = torch.as_tensor([num_gts], device=targets[0].device).float()
-
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_gts)
-        num_gts = torch.clamp(num_gts / get_world_size(), min=1).item()
-
-        return num_gts
-
-    def _get_prediction_loss(self,
-                             boxes,
-                             logits,
-                             gt_bbox,
-                             gt_class,
-                             postfix="",
-                             dn_match_indices=None,
-                             num_gts=1):
-        if dn_match_indices is None:
-            match_indices = self.matcher(boxes, logits, gt_bbox, gt_class)
-        else:
-            match_indices = dn_match_indices
-
-        if self.use_vfl:
-            if sum(len(a) for a in gt_bbox) > 0:
-                src_bbox, target_bbox = self._get_src_target_assign(
-                    boxes.detach(), gt_bbox, match_indices)
-                iou_score = bbox_iou(box_cxcywh_to_xyxy(src_bbox),
-                                     box_cxcywh_to_xyxy(target_bbox))
-            else:
-                iou_score = None
-        else:
-            iou_score = None
-
-        loss = dict()
-        loss.update(
-            self._get_loss_class(logits, gt_class, match_indices,
-                                 self.num_classes, num_gts, postfix, iou_score))
-        loss.update(
-            self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts,
-                                postfix))
-
-        return loss
-
-    def forward(self,
-                boxes,
-                logits,
-                gt_bbox,
-                gt_class,
-                postfix="",
-                **kwargs):
-        r"""
-        Args:
-            boxes (Tensor): [l, b, query, 4]
-            logits (Tensor): [l, b, query, num_classes]
-            gt_bbox (List(Tensor)): list[[n, 4]]
-            gt_class (List(Tensor)): list[[n, 1]]
-            masks (Tensor, optional): [l, b, query, h, w]
-            gt_mask (List(Tensor), optional): list[[n, H, W]]
-            postfix (str): postfix of loss name
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    def forward(self, outputs, targets):
+        """ This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depends on the losses applied, see each loss' doc
         """
+        outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
 
-        dn_match_indices = kwargs.get("dn_match_indices", None)
-        num_gts = kwargs.get("num_gts", None)
-        if num_gts is None:
-            num_gts = self._get_num_gts(gt_class)
-
-        total_loss = self._get_prediction_loss(
-            boxes[-1],
-            logits[-1],
-            gt_bbox,
-            gt_class,
-            postfix=postfix,
-            dn_match_indices=dn_match_indices,
-            num_gts=num_gts)
-
-        if self.aux_loss:
-            total_loss.update(
-                self._get_loss_aux(
-                    boxes[:-1],
-                    logits[:-1],
-                    gt_bbox,
-                    gt_class,
-                    self.num_classes,
-                    num_gts,
-                    dn_match_indices,
-                    postfix,
-                    ))
-
-        return total_loss
-
-class DINOLoss(DETRLoss):
-    def forward(self,
-                boxes,
-                logits,
-                gt_bbox,
-                gt_class,
-                postfix="",
-                dn_out_bboxes=None,
-                dn_out_logits=None,
-                dn_meta=None,
-                **kwargs):
-        num_gts = self._get_num_gts(gt_class)
-        total_loss = super(DINOLoss, self).forward(
-            boxes, logits, gt_bbox, gt_class, num_gts=num_gts)
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
 
-        if dn_meta is not None:
-            dn_positive_idx, dn_num_group = \
-                dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
-            assert len(gt_class) == len(dn_positive_idx)
-
-            # denoising match indices
-            dn_match_indices = self.get_dn_match_indices(
-                gt_class, dn_positive_idx, dn_num_group)
-
-            # compute denoising training loss
-            num_gts *= dn_num_group
-            dn_loss = super(DINOLoss, self).forward(
-                dn_out_bboxes,
-                dn_out_logits,
-                gt_bbox,
-                gt_class,
-                postfix="_dn",
-                dn_match_indices=dn_match_indices,
-                num_gts=num_gts)
-            total_loss.update(dn_loss)
-        else:
-            total_loss.update(
-                {k + '_dn': torch.as_tensor([0.])
-                 for k in total_loss.keys()})
-
-        return total_loss
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
+            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+            losses.update(l_dict)
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
+                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+                    l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        # In case of cdn auxiliary losses. For rtdetr
+        if 'dn_aux_outputs' in outputs:
+            assert 'dn_meta' in outputs, "dn_aux_outputs requires dn_meta"
+            indices = self.get_cdn_matched_indices(outputs['dn_meta'], targets)
+            num_boxes = num_boxes * outputs['dn_meta']['dn_num_group']
+
+            for i, aux_outputs in enumerate(outputs['dn_aux_outputs']):
+                # indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
+                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+                    l_dict = {k + f'_dn_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
 
     @staticmethod
-    def get_dn_match_indices(labels, dn_positive_idx, dn_num_group):
+    def get_cdn_matched_indices(dn_meta, targets):
+        '''Build matched (dn_positive_idx, gt_idx) index pairs for the contrastive denoising queries.'''
+        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
+        num_gts = [len(t['labels']) for t in targets]
+        device = targets[0]['labels'].device
+        
         dn_match_indices = []
-        for i in range(len(labels)):
-            num_gt = len(labels[i])
+        for i, num_gt in enumerate(num_gts):
             if num_gt > 0:
-                gt_idx = torch.arange(num_gt).long()
-                gt_idx = gt_idx.tile([dn_num_group])
+                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
+                gt_idx = gt_idx.tile(dn_num_group)
                 assert len(dn_positive_idx[i]) == len(gt_idx)
                 dn_match_indices.append((dn_positive_idx[i], gt_idx))
             else:
-                dn_match_indices.append((torch.zeros([0], device=labels[i].device).long(),
-                                         torch.zeros([0], device=labels[i].device).long()))
+                dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device),
+                                         torch.zeros(0, dtype=torch.int64, device=device)))
+
         return dn_match_indices

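Note: the rewritten criterion keys everything off the outputs dict and uses a varifocal-style classification target (weight = alpha * p**gamma on negatives, IoU-scaled one-hot on positives). A hedged end-to-end sketch; the cost and loss weights are example values, not taken from this commit's cfg, and the import path assumes a package layout matching the file paths above:

    import torch
    from models.detectors.rtdetr.matcher import HungarianMatcher
    from models.detectors.rtdetr.loss import Criterion

    bs, nq, nc = 2, 300, 80
    # example weights, not the repo's config values
    matcher = HungarianMatcher({'cost_class': 2.0, 'cost_bbox': 5.0, 'cost_giou': 2.0})
    criterion = Criterion(matcher, {'loss_cls': 1.0, 'loss_box': 5.0, 'loss_giou': 2.0}, num_classes=nc)
    outputs = {
        'pred_logits': torch.randn(bs, nq, nc),
        'pred_boxes':  torch.rand(bs, nq, 4),
        'aux_outputs': [{'pred_logits': torch.randn(bs, nq, nc), 'pred_boxes': torch.rand(bs, nq, 4)}],
    }
    targets = [{'labels': torch.tensor([3]), 'boxes': torch.rand(1, 4)} for _ in range(bs)]
    loss_dict = criterion(outputs, targets)  # loss_cls, loss_box, loss_giou, plus *_aux_0 variants
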
+ 69 - 32
models/detectors/rtdetr/matcher.py

@@ -4,50 +4,87 @@ import torch.nn.functional as F
 from scipy.optimize import linear_sum_assignment
 
 try:
-    from .loss_utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou
+    from .loss_utils import box_cxcywh_to_xyxy, generalized_box_iou
 except:
-    from  loss_utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou
+    from  loss_utils import box_cxcywh_to_xyxy, generalized_box_iou
 
 
 class HungarianMatcher(nn.Module):
-    def __init__(self, cost_class, cost_bbox, cost_giou, alpha=0.25, gamma=2.0):
+    """This class computes an assignment between the targets and the predictions of the network
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    __share__ = ['use_focal_loss', ]
+
+    def __init__(self, weight_dict, alpha=0.25, gamma=2.0):
+        """Creates the matcher
+
+        Params:
+            weight_dict: dict with keys 'cost_class', 'cost_bbox' and 'cost_giou' giving the
+                         relative weights of the classification, L1-box and GIoU terms in the matching cost
+        """
         super().__init__()
-        self.cost_class = cost_class
-        self.cost_bbox = cost_bbox
-        self.cost_giou = cost_giou
+        self.cost_class = weight_dict['cost_class']
+        self.cost_bbox = weight_dict['cost_bbox']
+        self.cost_giou = weight_dict['cost_giou']
+
         self.alpha = alpha
         self.gamma = gamma
 
+        assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs can't be 0"
+
     @torch.no_grad()
-    def forward(self, pred_boxes, pred_logits, gt_boxes, gt_labels):
-        bs, num_queries = pred_logits.shape[:2]
-        # [B, Nq, C] -> [BNq, C]
-        out_prob = pred_logits.flatten(0, 1).sigmoid()
-        out_bbox = pred_boxes.flatten(0, 1)
-
-        # List[B, M, C] -> [BM, C]
-        tgt_ids = torch.cat(gt_labels).long()
-        tgt_bbox = torch.cat(gt_boxes).float()
-
-        # -------------------- Classification cost --------------------
-        neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * (-(1 - out_prob + 1e-8).log())
-        pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
-        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
-
-        # -------------------- Regression cost --------------------
-        ## L1 cost: [Nq, M]
-        cost_bbox = torch.cdist(out_bbox, tgt_bbox.to(out_bbox.device), p=1)
-        ## GIoU cost: [Nq, M]
-        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
-                                         box_cxcywh_to_xyxy(tgt_bbox).to(out_bbox.device))
-
-        # Final cost: [B, Nq, M]
+    def forward(self, outputs, targets):
+        """ Performs the matching
+
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        out_prob = F.sigmoid(outputs["pred_logits"].flatten(0, 1))
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost
+        out_prob = out_prob[:, tgt_ids]
+        neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = self.alpha * ((1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log())
+        cost_class = pos_cost_class - neg_cost_class        
+
+        # Compute the L1 cost between boxes
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+        # Compute the giou cost between boxes
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+        
+        # Final cost matrix
         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
         C = C.view(bs, num_queries, -1).cpu()
 
-        # Label assignment
-        sizes = [len(t) for t in gt_boxes]
+        sizes = [len(v["boxes"]) for v in targets]
         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
 
         return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
-    

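Note: the matcher now reads predictions and targets from dicts, but its return format is unchanged: one (pred_idx, tgt_idx) pair per image. A quick sketch (weights illustrative, as above):

    import torch

    matcher = HungarianMatcher({'cost_class': 2.0, 'cost_bbox': 5.0, 'cost_giou': 2.0})
    outputs = {'pred_logits': torch.randn(1, 300, 80), 'pred_boxes': torch.rand(1, 300, 4)}
    targets = [{'labels': torch.tensor([0, 5]), 'boxes': torch.rand(2, 4)}]
    [(pred_idx, tgt_idx)] = matcher(outputs, targets)  # each of length 2 (= number of GT boxes)
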
+ 11 - 14
models/detectors/rtdetr/rtdetr.py

@@ -6,7 +6,7 @@ try:
     from .rtdetr_encoder import build_image_encoder
     from .rtdetr_decoder import build_transformer
 except:
-    from .basic_modules.basic import multiclass_nms
+    from  basic_modules.basic import multiclass_nms
     from  rtdetr_encoder import build_image_encoder
     from  rtdetr_decoder import build_transformer
 
@@ -114,15 +114,12 @@ class RT_DETR(nn.Module):
         pyramid_feats = self.image_encoder(x)
 
         # ----------- Transformer -----------
-        transformer_outputs = self.detect_decoder(pyramid_feats, targets)
+        outputs = self.detect_decoder(pyramid_feats, targets)
 
-        if self.training:
-            return transformer_outputs
-        else:
+        if not self.training:
             img_h, img_w = x.shape[2:]
-            pred_boxes, pred_logits = transformer_outputs[0], transformer_outputs[1]
-            box_pred = pred_boxes[-1]
-            cls_pred = pred_logits[-1]
+            box_pred = outputs["pred_boxes"]
+            cls_pred = outputs["pred_logits"]
 
             # rescale bbox
             box_pred[..., [0, 2]] *= img_h
@@ -137,7 +134,7 @@ class RT_DETR(nn.Module):
                 "bboxes": bboxes,
             }
 
-            return outputs
+        return outputs
         
 
 if __name__ == '__main__':
@@ -202,15 +199,15 @@ if __name__ == '__main__':
         }
     bs = 1
     # Create a batch of images & targets
-    image = torch.randn(bs, 3, 640, 640)
+    image = torch.randn(bs, 3, 640, 640).cuda()
     targets = [{
-        'labels': torch.tensor([2, 4, 5, 8]).long(),
-        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float() / 640.
+        'labels': torch.tensor([2, 4, 5, 8]).long().cuda(),
+        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float().cuda() / 640.
     }] * bs
 
     # Create model
     model = RT_DETR(cfg, num_classes=20)
-    model.train()
+    model.train().cuda()
 
     # Create criterion
     criterion = build_criterion(cfg, num_classes=20)
@@ -222,7 +219,7 @@ if __name__ == '__main__':
     print('Infer time: ', t1 - t0)
 
     # Compute loss
-    loss = criterion(*outputs, targets)
+    loss = criterion(outputs, targets)
     for k in loss.keys():
         print("{} : {}".format(k, loss[k].item()))
 

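Note: with this change the forward pass returns the decoder's dict as-is during training, and only the eval path post-processes it into a detection dict. The diff shows the 'bboxes' key explicitly; the elided keys are presumably the scores and labels produced by post-processing. A hedged eval sketch:

    # cfg: the config dict from the __main__ test above (elided in the diff)
    model = RT_DETR(cfg, num_classes=20).eval()
    with torch.no_grad():
        results = model(torch.randn(1, 3, 640, 640))
    # results['bboxes']: boxes rescaled to absolute pixel coordinates
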
+ 56 - 46
models/detectors/rtdetr/rtdetr_decoder.py

@@ -62,7 +62,8 @@ class RTDETRTransformer(nn.Module):
                  num_denoising       :int  = 100,
                  label_noise_ratio   :float = 0.5,
                  box_noise_scale     :float = 1.0,
-                 learnt_init_query   :bool  = True,
+                 learnt_init_query   :bool  = False,
+                 aux_loss            :bool  = True
                  ):
         super().__init__()
         # --------------- Basic setting ---------------
@@ -73,6 +74,7 @@ class RTDETRTransformer(nn.Module):
         self.pos_embed_type = pos_embed_type
         self.num_classes = num_classes
         self.eps = 1e-2
+        self.aux_loss = aux_loss
         ## Transformer parameters
         self.num_heads  = num_heads
         self.num_layers = num_layers
@@ -132,39 +134,37 @@ class RTDETRTransformer(nn.Module):
         self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
 
         ## Denoising part
-        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+        if num_denoising > 0: 
+            self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
 
         self._reset_parameters()
 
     def _reset_parameters(self):
-        def linear_init_(module):
-            bound = 1 / math.sqrt(module.weight.shape[0])
-            uniform_(module.weight, -bound, bound)
-            if hasattr(module, "bias") and module.bias is not None:
-                uniform_(module.bias, -bound, bound)
-
         # class and bbox head init
         prior_prob = 0.01
         cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
-        linear_init_(self.enc_class_head)
-        constant_(self.enc_class_head.bias, cls_bias_init)
-        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
-        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+
+        nn.init.constant_(self.enc_class_head.bias, cls_bias_init)
+        nn.init.constant_(self.enc_bbox_head.layers[-1].weight, 0.)
+        nn.init.constant_(self.enc_bbox_head.layers[-1].bias, 0.)
         for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
-            linear_init_(cls_)
-            constant_(cls_.bias, cls_bias_init)
-            constant_(reg_.layers[-1].weight, 0.)
-            constant_(reg_.layers[-1].bias, 0.)
+            nn.init.constant_(cls_.bias, cls_bias_init)
+            nn.init.constant_(reg_.layers[-1].weight, 0.)
+            nn.init.constant_(reg_.layers[-1].bias, 0.)
 
-        linear_init_(self.enc_output[0])
-        xavier_uniform_(self.enc_output[0].weight)
+        nn.init.xavier_uniform_(self.enc_output[0].weight)
         if self.learnt_init_query:
-            xavier_uniform_(self.tgt_embed.weight)
-        xavier_uniform_(self.query_pos_head.layers[0].weight)
-        xavier_uniform_(self.query_pos_head.layers[1].weight)
-        for l in self.input_proj_layers:
-            xavier_uniform_(l.conv.weight)
-        normal_(self.denoising_class_embed.weight)
+            nn.init.xavier_uniform_(self.tgt_embed.weight)
+        nn.init.xavier_uniform_(self.query_pos_head.layers[0].weight)
+        nn.init.xavier_uniform_(self.query_pos_head.layers[1].weight)
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class, outputs_coord)]
 
     def generate_anchors(self, spatial_shapes, grid_size=0.05):
         anchors = []
@@ -183,7 +183,7 @@ class RTDETRTransformer(nn.Module):
         valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
         # Equivalent to: anchors = torch.masked_fill(anchors, ~valid_mask, torch.as_tensor(float("inf")))
-        anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
+        anchors = torch.where(valid_mask, anchors, torch.inf)
         
         return anchors, valid_mask
     
@@ -239,9 +239,7 @@ class RTDETRTransformer(nn.Module):
 
         if denoising_bbox_unact is not None:
             reference_points_unact = torch.cat(
-                [denoising_bbox_unact, reference_points_unact], 1)
-        if self.training:
-            reference_points_unact = reference_points_unact.detach()
+                [denoising_bbox_unact, reference_points_unact], dim=1)
 
         # Extract region features
         if self.learnt_init_query:
@@ -250,27 +248,27 @@ class RTDETRTransformer(nn.Module):
         else:
             # [num_queries, c] -> [b, num_queries, c]
             target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
-            if self.training:
-                target = target.detach()
+            target = target.detach()
+        
         if denoising_class is not None:
             target = torch.cat([denoising_class, target], dim=1)
 
-        return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
+        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
     
     def forward(self, feats, targets=None):
         # input projection and embedding
         memory, spatial_shapes, _ = self.get_encoder_input(feats)
 
         # prepare denoising training
-        if self.training:
+        if self.training and self.num_denoising > 0:
             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
-                get_contrastive_denoising_training_group(targets,
-                                                         self.num_classes,
-                                                         self.num_queries,
-                                                         self.denoising_class_embed.weight,
-                                                         self.num_denoising,
-                                                         self.label_noise_ratio,
-                                                         self.box_noise_scale)
+                get_contrastive_denoising_training_group(targets,
+                                                         self.num_classes,
+                                                         self.num_queries,
+                                                         self.denoising_class_embed,
+                                                         num_denoising=self.num_denoising,
+                                                         label_noise_ratio=self.label_noise_ratio,
+                                                         box_noise_scale=self.box_noise_scale)
         else:
             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
@@ -287,8 +285,22 @@ class RTDETRTransformer(nn.Module):
                                               self.dec_class_head,
                                               self.query_pos_head,
                                               attn_mask)
-        
-        return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
+
+        if self.training and dn_meta is not None:
+            dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
+            dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
+
+        out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
+
+        if self.training and self.aux_loss:
+            out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
+            out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
+            
+            if self.training and dn_meta is not None:
+                out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
+                out['dn_meta'] = dn_meta
+
+        return out
 
 
 # ----------------- Decoder for Segmentation task -----------------
@@ -349,13 +361,11 @@ if __name__ == '__main__':
 
     t0 = time.time()
     outputs = model(pyramid_feats, targets)
-    out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
     t1 = time.time()
     print('Time: ', t1 - t0)
-    print(out_bboxes.shape)
-    print(out_logits.shape)
-    print(enc_topk_bboxes.shape)
-    print(enc_topk_logits.shape)
+
+    print(outputs["pred_logits"].shape)
+    print(outputs["pred_boxes"].shape)
 
     print('==============================')
     model.eval()
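
Note: for reference, the training-time dict assembled in forward() has the following layout (assuming num_layers = 6 decoder layers; the counts follow the code above):

    # out['pred_logits'], out['pred_boxes']: last decoder layer, matched queries only
    # out['aux_outputs']: 5 intermediate decoder layers + 1 encoder top-k entry = 6 dicts
    # out['dn_aux_outputs']: 6 dicts over the denoising split (all decoder layers)
    # out['dn_meta']: {'dn_positive_idx', 'dn_num_group', 'dn_num_split'}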