
add RTDETR-Transformer Decoder

yjh0410 committed 1 year ago
commit dbd1b401c8

+ 162 - 104
models/detectors/rtdetr/basic_modules/basic.py

@@ -1,8 +1,10 @@
 import math
 import copy
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.init import constant_, xavier_uniform_
 
 
 def get_clones(module, N):
@@ -11,6 +13,10 @@ def get_clones(module, N):
     else:
         return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
 
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0., max=1.)
+    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
+
 
 # ----------------- MLP modules -----------------
 class MLP(nn.Module):
@@ -44,7 +50,7 @@ class FFN(nn.Module):
         return src
     
 
-# ----------------- CNN modules -----------------
+# ----------------- Basic CNN Ops -----------------
 def get_conv2d(c1, c2, k, p, s, g, bias=False):
     conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
 
@@ -179,7 +185,8 @@ class PointwiseConv(nn.Module):
     def forward(self, x):
         return self.act(self.norm(self.conv(x)))
 
-## Yolov8's BottleNeck
+
+# ----------------- CNN Modules -----------------
 class Bottleneck(nn.Module):
     def __init__(self,
                  in_dim,
@@ -211,7 +218,6 @@ class Bottleneck(nn.Module):
 
         return x + h if self.shortcut else h
 
-# Yolov8's StageBlock
 class RTCBlock(nn.Module):
     def __init__(self,
                  in_dim,
@@ -243,60 +249,56 @@ class RTCBlock(nn.Module):
         return out
 
 
-# ----------------- Transformer modules -----------------
-## Basic ops of Deformable Attn
-def deformable_attention_core_func(value, value_spatial_shapes,
-                                   value_level_start_index, sampling_locations,
-                                   attention_weights):
-    """
-    Args:
-        value (Tensor): [bs, value_length, n_head, c]
-        value_spatial_shapes (Tensor|List): [n_levels, 2]
-        value_level_start_index (Tensor|List): [n_levels]
-        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
-        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
-
-    Returns:
-        output (Tensor): [bs, Length_{query}, C]
-    """
-    bs, _, n_head, c = value.shape
-    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
-
-    split_shape = [h * w for h, w in value_spatial_shapes]
-    value_list = value.split(split_shape, axis=1)
+# ----------------- Basic Transformer Ops -----------------
+def multi_scale_deformable_attn_pytorch(
+    value: torch.Tensor,
+    value_spatial_shapes: torch.Tensor,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+) -> torch.Tensor:
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+    
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
-    for level, (h, w) in enumerate(value_spatial_shapes):
-        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
-        value_l_ = value_list[level].flatten(2).transpose(
-            [0, 2, 1]).reshape([bs * n_head, c, h, w])
-        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(
-            [0, 2, 1, 3, 4]).flatten(0, 1)
-        # N_*M_, D_, Lq_, P_
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = (
+            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
+        )
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
         sampling_value_l_ = F.grid_sample(
-            value_l_,
-            sampling_grid_l_,
-            mode='bilinear',
-            padding_mode='zeros',
-            align_corners=False)
+            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+        )
         sampling_value_list.append(sampling_value_l_)
-    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
-    attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape(
-        [bs * n_head, 1, Len_q, n_levels * n_points])
-    output = (torch.stack(
-        sampling_value_list, axis=-2).flatten(-2) *
-              attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])
-
-    return output.transpose([0, 2, 1])
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs*num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points
+    )
+    output = (
+        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+        .sum(-1)
+        .view(bs, num_heads * embed_dims, num_queries)
+    )
+    return output.transpose(1, 2).contiguous()
 
-class MSDeformableAttention(nn.Layer):
+class MSDeformableAttention(nn.Module):
     def __init__(self,
                  embed_dim=256,
                  num_heads=8,
                  num_levels=4,
-                 num_points=4,
-                 lr_mult=0.1):
+                 num_points=4):
         """
         Multi-Scale Deformable Attention Module
         """
@@ -310,55 +312,51 @@ class MSDeformableAttention(nn.Layer):
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
-        self.sampling_offsets = nn.Linear(
-            embed_dim,
-            self.total_points * 2,
-            weight_attr=ParamAttr(learning_rate=lr_mult),
-            bias_attr=ParamAttr(learning_rate=lr_mult))
-
+        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
         self.attention_weights = nn.Linear(embed_dim, self.total_points)
         self.value_proj = nn.Linear(embed_dim, embed_dim)
         self.output_proj = nn.Linear(embed_dim, embed_dim)
+        
         try:
             # use cuda op
             from deformable_detr_ops import ms_deformable_attn
             self.ms_deformable_attn_core = ms_deformable_attn
         except:
-            # use paddle func
-            self.ms_deformable_attn_core = deformable_attention_core_func
+            # use torch func
+            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch
 
         self._reset_parameters()
 
     def _reset_parameters(self):
-        # sampling_offsets
-        constant_(self.sampling_offsets.weight)
-        thetas = paddle.arange(
-            self.num_heads,
-            dtype=paddle.float32) * (2.0 * math.pi / self.num_heads)
-        grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1)
-        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)
-        grid_init = grid_init.reshape([self.num_heads, 1, 1, 2]).tile(
-            [1, self.num_levels, self.num_points, 1])
-        scaling = paddle.arange(
-            1, self.num_points + 1,
-            dtype=paddle.float32).reshape([1, 1, -1, 1])
-        grid_init *= scaling
-        self.sampling_offsets.bias.set_value(grid_init.flatten())
-        # attention_weights
-        constant_(self.attention_weights.weight)
-        constant_(self.attention_weights.bias)
-        # proj
-        xavier_uniform_(self.value_proj.weight)
-        constant_(self.value_proj.bias)
-        xavier_uniform_(self.output_proj.weight)
-        constant_(self.output_proj.bias)
+        """
+        Default initialization for Parameters of Module.
+        """
+        constant_(self.sampling_offsets.weight.data, 0.0)
+        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
+            2.0 * math.pi / self.num_heads
+        )
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (
+            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+            .view(self.num_heads, 1, 1, 2)
+            .repeat(1, self.num_levels, self.num_points, 1)
+        )
+        for i in range(self.num_points):
+            grid_init[:, :, i, :] *= i + 1
+        with torch.no_grad():
+            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+        constant_(self.attention_weights.weight.data, 0.0)
+        constant_(self.attention_weights.bias.data, 0.0)
+        xavier_uniform_(self.value_proj.weight.data)
+        constant_(self.value_proj.bias.data, 0.0)
+        xavier_uniform_(self.output_proj.weight.data)
+        constant_(self.output_proj.bias.data, 0.0)
 
     def forward(self,
                 query,
                 reference_points,
                 value,
                 value_spatial_shapes,
-                value_level_start_index,
                 value_mask=None):
         """
         Args:
@@ -367,52 +365,70 @@ class MSDeformableAttention(nn.Layer):
                 bottom-right (1, 1), including padding area
             value (Tensor): [bs, value_length, C]
             value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
-            value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
             value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
 
         Returns:
             output (Tensor): [bs, Length_{query}, C]
         """
-        bs, Len_q = query.shape[:2]
-        Len_v = value.shape[1]
-        assert int(value_spatial_shapes.prod(1).sum()) == Len_v
+        bs, num_query = query.shape[:2]
+        num_value = value.shape[1]
+        assert int(value_spatial_shapes.prod(1).sum()) == num_value
 
+        # Value projection
         value = self.value_proj(value)
+        # fill "0" for the padding part
         if value_mask is not None:
             value_mask = value_mask.astype(value.dtype).unsqueeze(-1)
             value *= value_mask
-        value = value.reshape([bs, Len_v, self.num_heads, self.head_dim])
+        # [bs, all_hw, 256] -> [bs, all_hw, num_head, head_dim]
+        value = value.reshape([bs, num_value, self.num_heads, -1])
 
+        # [bs, num_query, num_head, num_level, num_point, 2]
         sampling_offsets = self.sampling_offsets(query).reshape(
-            [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2])
+            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
+        # [bs, num_query, num_head, num_level*num_point]
         attention_weights = self.attention_weights(query).reshape(
-            [bs, Len_q, self.num_heads, self.num_levels * self.num_points])
-        attention_weights = F.softmax(attention_weights).reshape(
-            [bs, Len_q, self.num_heads, self.num_levels, self.num_points])
+            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
+        attention_weights = attention_weights.softmax(-1)
+        # [bs, num_query, num_head, num_level, num_point]
+        attention_weights = attention_weights.reshape(
+            [bs, num_query, self.num_heads, self.num_levels, self.num_points])
 
+        # [bs, num_query, num_heads, num_levels, num_points, 2]
         if reference_points.shape[-1] == 2:
+            # reference_points   [bs, num_query, num_level, 2] -> [bs, num_query, 1, num_level, 1, 2]
+            # sampling_offsets   [bs, num_query, num_head, num_level, num_point, 2]
+            # offset_normalizer  [num_level, 2] -> [1, 1, 1, num_level, 1, 2]
+            # sampling_locations = reference_points + normalized sampling_offsets
             offset_normalizer = value_spatial_shapes.flip([1]).reshape(
                 [1, 1, 1, self.num_levels, 1, 2])
-            sampling_locations = reference_points.reshape([
-                bs, Len_q, 1, self.num_levels, 1, 2
-            ]) + sampling_offsets / offset_normalizer
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :]
+                + sampling_offsets / offset_normalizer
+            )
         elif reference_points.shape[-1] == 4:
             sampling_locations = (
-                reference_points[:, :, None, :, None, :2] + sampling_offsets /
-                self.num_points * reference_points[:, :, None, :, None, 2:] *
-                0.5)
+                reference_points[:, :, None, :, None, :2]
+                + sampling_offsets
+                / self.num_points
+                * reference_points[:, :, None, :, None, 2:]
+                * 0.5)
         else:
             raise ValueError(
                 "Last dim of reference_points must be 2 or 4, but get {} instead.".
                 format(reference_points.shape[-1]))
 
+        # Multi-scale Deformable attention
         output = self.ms_deformable_attn_core(
-            value, value_spatial_shapes, value_level_start_index,
-            sampling_locations, attention_weights)
+            value, value_spatial_shapes, sampling_locations, attention_weights)
+        
+        # Output projection
         output = self.output_proj(output)
 
         return output
 
+
+# ----------------- Transformer modules -----------------
 ## Transformer Encoder layer
 class TransformerEncoderLayer(nn.Module):
     def __init__(self,
@@ -535,7 +551,7 @@ class TransformerEncoder(nn.Module):
         return src
 
 ## Transformer Decoder layer
-class TransformerDecoderLayer(nn.Module):
+class DeformableTransformerDecoderLayer(nn.Module):
     def __init__(self,
                  d_model         :int   = 256,
                  num_heads       :int   = 8,
@@ -560,7 +576,7 @@ class TransformerDecoderLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.norm1 = nn.LayerNorm(d_model)
         ## CrossAttention
-        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points, 1.0)
+        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
         self.dropout2 = nn.Dropout(dropout)
         self.norm2 = nn.LayerNorm(d_model)
         ## FFN
@@ -574,7 +590,6 @@ class TransformerDecoderLayer(nn.Module):
                 reference_points,
                 memory,
                 memory_spatial_shapes,
-                memory_level_start_index,
                 attn_mask=None,
                 memory_mask=None,
                 query_pos_embed=None):
@@ -585,7 +600,7 @@ class TransformerDecoderLayer(nn.Module):
                 attn_mask.to(torch.bool),
                 torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=tgt.device),
                 torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=tgt.device))
-        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
+        tgt2 = self.self_attn(q, k, value=tgt)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
 
@@ -594,7 +609,6 @@ class TransformerDecoderLayer(nn.Module):
                                reference_points,
                                memory,
                                memory_spatial_shapes,
-                               memory_level_start_index,
                                memory_mask)
         tgt = tgt + self.dropout2(tgt2)
         tgt = self.norm2(tgt)
@@ -605,15 +619,18 @@ class TransformerDecoderLayer(nn.Module):
         return tgt
 
 ## Transformer Decoder
-class TransformerDecoder(nn.Module):
+class DeformableTransformerDecoder(nn.Module):
     def __init__(self,
                  d_model        :int   = 256,
                  num_heads      :int   = 8,
                  num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
                  mlp_ratio      :float = 4.0,
                  pe_temperature :float = 10000.,
                  dropout        :float = 0.1,
                  act_type       :str   = "relu",
+                 return_intermediate :bool = False,
                  ):
         super().__init__()
         # ----------- Basic parameters -----------
@@ -625,8 +642,11 @@ class TransformerDecoder(nn.Module):
         self.act_type = act_type
         self.pe_temperature = pe_temperature
         self.pos_embed = None
-        # ----------- Basic parameters -----------
-        self.decoder_layers = None
+        # ----------- Network parameters -----------
+        self.decoder_layers = get_clones(
+            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)
+        self.num_layers = num_layers
+        self.return_intermediate = return_intermediate
 
     def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
         assert embed_dim % 4 == 0, \
@@ -655,3 +675,41 @@ class TransformerDecoder(nn.Module):
 
         return pos_embed
 
+    def forward(self,
+                tgt,
+                ref_points_unact,
+                memory,
+                memory_spatial_shapes,
+                bbox_head,
+                score_head,
+                query_pos_head,
+                attn_mask=None,
+                memory_mask=None):
+        output = tgt
+        dec_out_bboxes = []
+        dec_out_logits = []
+        ref_points_detach = F.sigmoid(ref_points_unact)
+        for i, layer in enumerate(self.decoder_layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            query_pos_embed = query_pos_head(ref_points_detach)
+
+            output = layer(output, ref_points_input, memory,
+                           memory_spatial_shapes, attn_mask,
+                           memory_mask, query_pos_embed)
+
+            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
+                ref_points_detach))
+
+            dec_out_logits.append(score_head[i](output))
+            if i == 0:
+                dec_out_bboxes.append(inter_ref_bbox)
+            else:
+                dec_out_bboxes.append(
+                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
+                        ref_points)))
+
+            ref_points = inter_ref_bbox
+            ref_points_detach = inter_ref_bbox.detach()
+
+        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+
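As a sanity check, the new pure-PyTorch deformable-attention kernel can be exercised on random tensors. This is a minimal sketch with illustrative sizes; it assumes the repository root is on PYTHONPATH and is not part of the commit:

import torch
from models.detectors.rtdetr.basic_modules.basic import multi_scale_deformable_attn_pytorch

bs, num_heads, head_dim = 2, 8, 32
spatial_shapes = [(32, 32), (16, 16), (8, 8)]              # three feature levels
num_value = sum(h * w for h, w in spatial_shapes)          # 1024 + 256 + 64 = 1344
num_queries, num_levels, num_points = 300, 3, 4

value = torch.randn(bs, num_value, num_heads, head_dim)
# sampling locations are normalized to [0, 1]; weights are normalized over (levels, points)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)

out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 300, 256]) == [bs, num_queries, num_heads * head_dim]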

+ 122 - 0
models/detectors/rtdetr/basic_modules/dn_compoments.py

@@ -0,0 +1,122 @@
+import torch
+
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0., max=1.)
+    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
+
+def bbox_cxcywh_to_xyxy(x):
+    cxcy, wh = torch.split(x, 2, dim=-1)
+    return torch.cat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], dim=-1)
+
+def bbox_xyxy_to_cxcywh(x):
+    x1, y1, x2, y2 = x.split(1, dim=-1)
+    return torch.cat(
+        [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], dim=-1)
+
+def get_contrastive_denoising_training_group(targets,
+                                             num_classes,
+                                             num_queries,
+                                             class_embed,
+                                             num_denoising=100,
+                                             label_noise_ratio=0.5,
+                                             box_noise_scale=1.0):
+    if num_denoising <= 0:
+        return None, None, None, None
+    num_gts = [len(t) for t in targets["labels"]]
+    max_gt_num = max(num_gts)
+    if max_gt_num == 0:
+        return None, None, None, None
+
+    num_group = num_denoising // max_gt_num
+    num_group = 1 if num_group == 0 else num_group
+
+    # pad gt to max_num of a batch
+    bs = len(targets["labels"])
+    input_query_class = torch.full([bs, max_gt_num], num_classes).long()
+    input_query_bbox = torch.zeros([bs, max_gt_num, 4])
+    pad_gt_mask = torch.zeros([bs, max_gt_num])
+    for i in range(bs):
+        num_gt = num_gts[i]
+        if num_gt > 0:
+            input_query_class[i, :num_gt] = targets["labels"][i].squeeze(-1)
+            input_query_bbox[i, :num_gt] = targets["boxes"][i]
+            pad_gt_mask[i, :num_gt] = 1
+
+    # each group has positive and negative queries.
+    input_query_class = input_query_class.repeat(1, 2 * num_group)
+    input_query_bbox = input_query_bbox.repeat(1, 2 * num_group, 1)
+    pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)
+
+    # positive and negative mask
+    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1])
+    negative_gt_mask[:, max_gt_num:] = 1
+    negative_gt_mask = negative_gt_mask.repeat(1, num_group, 1)
+    positive_gt_mask = 1 - negative_gt_mask
+
+    # contrastive denoising training positive index
+    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+    dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+    dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
+    
+    # total denoising queries
+    num_denoising = int(max_gt_num * 2 * num_group)
+
+    if label_noise_ratio > 0:
+        input_query_class = input_query_class.flatten()
+        pad_gt_mask = pad_gt_mask.flatten()
+        # half of bbox prob
+        mask = torch.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
+        chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
+        # randomly put a new one here
+        new_label = torch.randint_like(
+            chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class.scatter_(0, chosen_idx, new_label)
+        input_query_class = input_query_class.reshape(bs, num_denoising)
+        pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising)
+
+    if box_noise_scale > 0:
+        known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5,
+                           [1, 1, 2]) * box_noise_scale
+
+        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand(input_query_bbox.shape)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
+            1 - negative_gt_mask)
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
+        input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+    class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
+    input_query_class = class_embed[input_query_class.flatten()]
+    input_query_class = input_query_class.reshape(bs, num_denoising, -1)
+    
+    tgt_size = num_denoising + num_queries
+    attn_mask = torch.ones([tgt_size, tgt_size]) < 0
+    # match query cannot see the reconstruction
+    attn_mask[num_denoising:, :num_denoising] = True
+    # reconstruct cannot see each other
+    for i in range(num_group):
+        if i == 0:
+            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
+                      2 * (i + 1):num_denoising] = True
+        if i == num_group - 1:
+            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
+                      i * 2] = True
+        else:
+            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
+                      2 * (i + 1):num_denoising] = True
+            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
+                      2 * i] = True
+    attn_mask = ~attn_mask
+    dn_meta = {
+        "dn_positive_idx": dn_positive_idx,
+        "dn_num_group": num_group,
+        "dn_num_split": [num_denoising, num_queries]
+    }
+
+    return input_query_class, input_query_bbox, attn_mask, dn_meta
+
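A minimal usage sketch for the contrastive-denoising group builder, with dummy targets (two images with 2 and 1 ground-truth boxes; the random class-embedding matrix stands in for denoising_class_embed.weight, and the import path assumes the repository root is on PYTHONPATH):

import torch
from models.detectors.rtdetr.basic_modules.dn_compoments import get_contrastive_denoising_training_group

num_classes, num_queries, hidden_dim = 80, 300, 256
targets = {
    "labels": [torch.tensor([3, 7]), torch.tensor([15])],  # per-image class indices
    "boxes":  [torch.rand(2, 4),     torch.rand(1, 4)],    # per-image normalized cxcywh boxes
}
class_embed = torch.randn(num_classes, hidden_dim)

dn_class, dn_bbox_unact, attn_mask, dn_meta = get_contrastive_denoising_training_group(
    targets, num_classes, num_queries, class_embed, num_denoising=100)

print(dn_class.shape)           # [2, 200, 256]  max_gt_num(2) * 2 * num_group(50) denoising queries
print(dn_bbox_unact.shape)      # [2, 200, 4]    boxes in inverse-sigmoid space
print(attn_mask.shape)          # [500, 500]     200 denoising + 300 matching queries
print(dn_meta["dn_num_split"])  # [200, 300]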

+ 255 - 10
models/detectors/rtdetr/rtdetr_decoder.py

@@ -1,23 +1,267 @@
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.init import constant_, xavier_uniform_, uniform_
+from typing import List
+
+try:
+    from .basic_modules.basic import BasicConv, MLP, DeformableTransformerDecoder
+    from .basic_modules.dn_compoments import get_contrastive_denoising_training_group
+except ImportError:
+    from basic_modules.basic import BasicConv, MLP, DeformableTransformerDecoder
+    from basic_modules.dn_compoments import get_contrastive_denoising_training_group
 
 
 # ----------------- Decoder for Detection task -----------------
-## RTDETR's Transformer
-class DetDecoder(nn.Module):
-    def __init__(self, ):
+## RTDETR's Transformer for Detection task
+class RTDETRTransformer(nn.Module):
+    def __init__(self,
+                 # basic parameters
+                 in_dims        :List = [256, 512, 1024],
+                 hidden_dim     :int  = 256,
+                 strides        :List = [8, 16, 32],
+                 num_classes    :int  = 80,
+                 num_queries    :int  = 300,
+                 pos_embed_type :str  = 'sine',
+                 trainable      :bool = False,
+                 # transformer parameters
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 mlp_ratio      :float = 4.0,
+                 pe_temperature :float = 10000.,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 # Denoising parameters
+                 num_denoising       :int  = 100,
+                 label_noise_ratio   :float = 0.5,
+                 box_noise_scale     :float = 1.0,
+                 learnt_init_query   :bool  = True,
+                 ):
         super().__init__()
-        self.backbone = None
-        self.neck = None
-        self.fpn = None
+        # --------------- Basic setting ---------------
+        ## Basic parameters
+        self.in_dims = in_dims
+        self.strides = strides
+        self.trainable = trainable
+        self.num_queries = num_queries
+        self.pos_embed_type = pos_embed_type
+        self.num_classes = num_classes
+        self.eps = 1e-2
+        ## Transformer parameters
+        self.num_heads  = num_heads
+        self.num_layers = num_layers
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.mlp_ratio  = mlp_ratio
+        self.dropout    = dropout
+        self.act_type   = act_type
+        self.pe_temperature = pe_temperature
+        self.return_intermediate = return_intermediate
+        ## Denoising parameters
+        self.num_denoising = num_denoising
+        self.label_noise_ratio = label_noise_ratio
+        self.box_noise_scale = box_noise_scale
+        self.learnt_init_query = learnt_init_query
 
-    def forward(self, x):
-        return
+        # --------------- Network setting ---------------
+        ## Input proj layers
+        self.input_proj_layers = nn.ModuleList(
+            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
+            for i in range(num_levels)
+        )
+
+        ## Deformable transformer decoder
+        self.transformer_decoder = DeformableTransformerDecoder(
+            d_model    = hidden_dim,
+            num_heads  = num_heads,
+            num_layers = num_layers,
+            num_levels = num_levels,
+            num_points = num_points,
+            mlp_ratio  = mlp_ratio,
+            pe_temperature = pe_temperature,
+            dropout        = dropout,
+            act_type       = act_type,
+            return_intermediate = return_intermediate
+            )
+        
+        ## Detection head for Encoder
+        self.enc_output = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim)
+            )
+        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
+        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+
+        ##  Detection head for Decoder
+        self.dec_class_head = nn.ModuleList([
+            nn.Linear(hidden_dim, num_classes)
+            for _ in range(num_layers)
+        ])
+        self.dec_bbox_head = nn.ModuleList([
+            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+            for _ in range(num_layers)
+        ])
+
+        ## Denoising part
+        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+
+        ## Object query
+        if learnt_init_query:
+            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
+        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        def _linear_init(module):
+            bound = 1 / math.sqrt(module.weight.shape[1])  # fan_in: nn.Linear weight is [out_features, in_features]
+            uniform_(module.weight, -bound, bound)
+            if hasattr(module, "bias") and module.bias is not None:
+                uniform_(module.bias, -bound, bound)
+
+        # class and bbox head init
+        prior_prob = 0.01
+        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
+        _linear_init(self.enc_class_head)
+        constant_(self.enc_class_head.bias, cls_bias_init)
+        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
+        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
+            _linear_init(cls_)
+            constant_(cls_.bias, cls_bias_init)
+            constant_(reg_.layers[-1].weight, 0.)
+            constant_(reg_.layers[-1].bias, 0.)
+
+        _linear_init(self.enc_output[0])
+        xavier_uniform_(self.enc_output[0].weight)
+        if self.learnt_init_query:
+            xavier_uniform_(self.tgt_embed.weight)
+        xavier_uniform_(self.query_pos_head.layers[0].weight)
+        xavier_uniform_(self.query_pos_head.layers[1].weight)
+        for layer in self.input_proj_layers:
+            xavier_uniform_(layer.conv.weight)  # assumes BasicConv exposes its Conv2d as `.conv`
+
+    def generate_anchors(self, spatial_shapes, grid_size=0.05):
+        anchors = []
+        for lvl, (h, w) in enumerate(spatial_shapes):
+            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
+
+            valid_WH = torch.as_tensor([w, h]).float()
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
+            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
+            anchors.append(torch.cat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
+
+        anchors = torch.cat(anchors, 1)
+        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
+        anchors = torch.log(anchors / (1 - anchors))
+        anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
+        
+        return anchors, valid_mask
+    
+    def get_encoder_input(self, feats):
+        # get projection features
+        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
+
+        # get encoder inputs
+        feat_flatten = []
+        spatial_shapes = []
+        level_start_index = [0, ]
+        for i, feat in enumerate(proj_feats):
+            _, _, h, w = feat.shape
+            # [b, c, h, w] -> [b, h*w, c]
+            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
+            # [num_levels, 2]
+            spatial_shapes.append([h, w])
+            # [l], start index of each level
+            level_start_index.append(h * w + level_start_index[-1])
+
+        # [b, l, c]
+        feat_flatten = torch.cat(feat_flatten, 1)
+        level_start_index.pop()
+
+        return (feat_flatten, spatial_shapes, level_start_index)
+
+    def get_decoder_input(self,
+                          memory,
+                          spatial_shapes,
+                          denoising_class=None,
+                          denoising_bbox_unact=None):
+        bs, _, _ = memory.shape
+        # prepare input for decoder
+        anchors, valid_mask = self.generate_anchors(spatial_shapes)
+        memory = torch.where(valid_mask, memory, torch.as_tensor(0.))
+        output_memory = self.enc_output(memory)
+
+        enc_outputs_class = self.enc_class_head(output_memory)
+        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
+
+        topk = self.num_queries
+        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]
+        reference_points_unact = torch.gather(enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))
+        enc_topk_bboxes = F.sigmoid(reference_points_unact)
+
+        if denoising_bbox_unact is not None:
+            reference_points_unact = torch.cat(
+                [denoising_bbox_unact, reference_points_unact], 1)
+        if self.trainable:
+            reference_points_unact = reference_points_unact.detach()
+        enc_topk_logits = torch.gather(enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))
+
+        # extract region features
+        if self.learnt_init_query:
+            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+        else:
+            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+            if self.training:
+                target = target.detach()
+        if denoising_class is not None:
+            target = torch.cat([denoising_class, target], dim=1)
+
+        return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
+    
+    def forward(self, feats, gt_meta=None):
+        # input projection and embedding
+        memory, spatial_shapes, _ = self.get_encoder_input(feats)
+
+        # prepare denoising training
+        if self.training:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
+                get_contrastive_denoising_training_group(gt_meta,
+                                                         self.num_classes,
+                                                         self.num_queries,
+                                                         self.denoising_class_embed.weight,
+                                                         self.num_denoising,
+                                                         self.label_noise_ratio,
+                                                         self.box_noise_scale)
+        else:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
+
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
+            self.get_decoder_input(
+            memory, spatial_shapes, denoising_class, denoising_bbox_unact)
+
+        # decoder
+        out_bboxes, out_logits = self.transformer_decoder(target,
+                                                          init_ref_points_unact,
+                                                          memory,
+                                                          spatial_shapes,
+                                                          self.dec_bbox_head,
+                                                          self.dec_class_head,
+                                                          self.query_pos_head,
+                                                          attn_mask)
+        
+        return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
+                dn_meta)
 
 
 # ----------------- Decoder for Segmentation task -----------------
-class SegDecoder(nn.Module):
+## RTDETR's Transformer for Segmentation task
+class SegTransformerDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()
         # TODO: design seg-decoder
@@ -27,7 +271,8 @@ class SegDecoder(nn.Module):
 
 
 # ----------------- Decoder for Pose estimation task -----------------
-class PosDecoder(nn.Module):
+## RTDETR's Transformer for Pose estimation task
+class PosTransformerDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()
         # TODO: design pose-decoder
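Finally, a rough sketch of the encoder-side query selection in the new RTDETRTransformer on dummy multi-scale features (eval mode, so the denoising branch is skipped; the feature sizes are illustrative and the import again assumes the repository root is on PYTHONPATH):

import torch
from models.detectors.rtdetr.rtdetr_decoder import RTDETRTransformer

decoder = RTDETRTransformer(in_dims=[256, 512, 1024], hidden_dim=256,
                            num_classes=80, num_queries=300,
                            num_layers=6, num_levels=3).eval()

feats = [torch.randn(2, 256, 80, 80),    # stride  8
         torch.randn(2, 512, 40, 40),    # stride 16
         torch.randn(2, 1024, 20, 20)]   # stride 32

with torch.no_grad():
    memory, spatial_shapes, _ = decoder.get_encoder_input(feats)
    print(memory.shape)                   # [2, 8400, 256]  flattened multi-scale tokens
    target, ref_unact, enc_bboxes, enc_logits = decoder.get_decoder_input(memory, spatial_shapes)
    print(target.shape, ref_unact.shape)  # [2, 300, 256]  [2, 300, 4]
    print(enc_bboxes.shape)               # [2, 300, 4]    top-k encoder proposals (sigmoid space)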