
complete rtdetr model

yjh0410 1 year ago
parent
commit
c7174f216d

+ 10 - 0
README.md

@@ -20,10 +20,20 @@ conda activate rtcdet
 ```
 
 - Requirements:
+1. Install the necessary libraries
 ```Shell
 pip install -r requirements.txt 
 ```
 
+2. (optional) Compile the MSDeformableAttention CUDA ops for the DETR series
+
+```bash
+cd ./models/detectors/rtdetr/basic_modules/ext_op/
+
+python setup_ms_deformable_attn_op.py install
+```
+See [details](./models/detectors/rtdetr/basic_modules/ext_op/)
+
 My environment:
 - PyTorch = 1.9.1
 - Torchvision = 0.10.1
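
If the optional op is not compiled, nothing breaks: the attention module guards the import and falls back to a pure-PyTorch implementation. A minimal sketch of that guard, mirroring the `try`/`except` used in `MSDeformableAttention` in this commit:

```python
# Sketch only: prefer the compiled CUDA op, fall back to the PyTorch path.
try:
    from deformable_detr_ops import ms_deformable_attn  # compiled extension
except ImportError:
    ms_deformable_attn = None  # the module then uses multi_scale_deformable_attn_pytorch
```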

+ 11 - 0
README_CN.md

@@ -18,10 +18,21 @@ conda activate rtcdet
 ```
 
 - Then, set up the environment:
+1. First, install the basic dependencies
 ```Shell
 pip install -r requirements.txt 
 ```
 
+2. (optional) Next, consider compiling the CUDA version of the MSDeformableAttention operator in order to use the DETR-series detectors
+
+```bash
+cd ./models/detectors/rtdetr/basic_modules/ext_op/
+
+python setup_ms_deformable_attn_op.py install
+```
+See [details](./models/detectors/rtdetr/basic_modules/ext_op/)
+
+
 The environment used by the project author:
 - PyTorch = 1.9.1
 - Torchvision = 0.10.1

+ 5 - 0
models/detectors/__init__.py

@@ -13,6 +13,7 @@ from .yolov8.build import build_yolov8
 from .yolox.build import build_yolox
 # My RTCDet series
 from .rtcdet.build import build_rtcdet
+from .rtdetr.build import build_rtdetr
 
 
 # build object detector
@@ -66,6 +67,10 @@ def build_model(args,
     elif args.model in ['rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
         model, criterion = build_rtcdet(
             args, model_cfg, device, num_classes, trainable, deploy)
+    # RT-DETR
+    elif args.model in ['rtdetr_r18', 'rtdetr_r34', 'rtdetr_r50', 'rtdetr_r101']:
+        model, criterion = build_rtdetr(
+            args, model_cfg, num_classes, trainable, deploy)
 
     if trainable:
         # Load pretrained weight
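
A hypothetical call into the dispatch above (the remaining `build_model` parameters are assumptions based on the `build_rtcdet` branch shown in the same hunk):

```python
# Sketch only: the model names come from the new elif branch.
from types import SimpleNamespace

args = SimpleNamespace(model="rtdetr_r18")  # or rtdetr_r34 / rtdetr_r50 / rtdetr_r101
# model, criterion = build_model(args, model_cfg, device, num_classes,
#                                trainable=True, deploy=False)
```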

+ 0 - 482
models/detectors/rtdetr/basic_modules/basic.py

@@ -1,21 +1,5 @@
-import math
-import copy
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_
-
-
-def get_clones(module, N):
-    if N <= 0:
-        return None
-    else:
-        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
-
-def inverse_sigmoid(x, eps=1e-5):
-    x = x.clamp(min=0., max=1.)
-    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
 
 
 # ----------------- MLP modules -----------------
@@ -247,469 +231,3 @@ class RTCBlock(nn.Module):
         out = self.output_proj(torch.cat(out, dim=1))
 
         return out
-
-
-# ----------------- Basic Transformer Ops -----------------
-def multi_scale_deformable_attn_pytorch(
-    value: torch.Tensor,
-    value_spatial_shapes: torch.Tensor,
-    sampling_locations: torch.Tensor,
-    attention_weights: torch.Tensor,
-) -> torch.Tensor:
-
-    bs, _, num_heads, embed_dims = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    
-    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level, (H_, W_) in enumerate(value_spatial_shapes):
-        # bs, H_*W_, num_heads, embed_dims ->
-        # bs, H_*W_, num_heads*embed_dims ->
-        # bs, num_heads*embed_dims, H_*W_ ->
-        # bs*num_heads, embed_dims, H_, W_
-        value_l_ = (
-            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
-        )
-        # bs, num_queries, num_heads, num_points, 2 ->
-        # bs, num_heads, num_queries, num_points, 2 ->
-        # bs*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
-        # bs*num_heads, embed_dims, num_queries, num_points
-        sampling_value_l_ = F.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (bs, num_queries, num_heads, num_levels, num_points) ->
-    # (bs, num_heads, num_queries, num_levels, num_points) ->
-    # (bs, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        bs * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(bs, num_heads * embed_dims, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
-
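
For reference, a quick shape check of the PyTorch fallback above, with toy sizes (all numbers are illustrative, and the function is assumed importable from this module):

```python
import torch

bs, num_heads, head_dim = 2, 8, 32
shapes = torch.as_tensor([[6, 4], [3, 2]], dtype=torch.long)  # per-level (H, W)
num_value = int(shapes.prod(1).sum())                         # 24 + 6 = 30
num_queries, num_levels, num_points = 5, 2, 4

value = torch.rand(bs, num_value, num_heads, head_dim)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points).softmax(-1)

out = multi_scale_deformable_attn_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 5, 256]) == [bs, num_queries, num_heads * head_dim]
```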
-class MSDeformableAttention(nn.Module):
-    def __init__(self,
-                 embed_dim=256,
-                 num_heads=8,
-                 num_levels=4,
-                 num_points=4):
-        """
-        Multi-Scale Deformable Attention Module
-        """
-        super(MSDeformableAttention, self).__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.num_levels = num_levels
-        self.num_points = num_points
-        self.total_points = num_heads * num_levels * num_points
-
-        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-
-        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
-        self.attention_weights = nn.Linear(embed_dim, self.total_points)
-        self.value_proj = nn.Linear(embed_dim, embed_dim)
-        self.output_proj = nn.Linear(embed_dim, embed_dim)
-        
-        try:
-            # use cuda op
-            from deformable_detr_ops import ms_deformable_attn
-            self.ms_deformable_attn_core = ms_deformable_attn
-        except:
-            # use torch func
-            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch
-
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        """
-        Default initialization for Parameters of Module.
-        """
-        constant_(self.sampling_offsets.weight.data, 0.0)
-        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
-            2.0 * math.pi / self.num_heads
-        )
-        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
-        grid_init = (
-            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
-            .view(self.num_heads, 1, 1, 2)
-            .repeat(1, self.num_levels, self.num_points, 1)
-        )
-        for i in range(self.num_points):
-            grid_init[:, :, i, :] *= i + 1
-        with torch.no_grad():
-            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
-        constant_(self.attention_weights.weight.data, 0.0)
-        constant_(self.attention_weights.bias.data, 0.0)
-        xavier_uniform_(self.value_proj.weight.data)
-        constant_(self.value_proj.bias.data, 0.0)
-        xavier_uniform_(self.output_proj.weight.data)
-        constant_(self.output_proj.bias.data, 0.0)
-
-    def forward(self,
-                query,
-                reference_points,
-                value,
-                value_spatial_shapes,
-                value_mask=None):
-        """
-        Args:
-            query (Tensor): [bs, query_length, C]
-            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
-                bottom-right (1, 1), including padding area
-            value (Tensor): [bs, value_length, C]
-            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
-            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
-
-        Returns:
-            output (Tensor): [bs, Length_{query}, C]
-        """
-        bs, num_query = query.shape[:2]
-        num_value = value.shape[1]
-        assert int(value_spatial_shapes.prod(1).sum()) == num_value
-
-        # Value projection
-        value = self.value_proj(value)
-        # fill "0" for the padding part
-        if value_mask is not None:
-            value_mask = value_mask.astype(value.dtype).unsqueeze(-1)
-            value *= value_mask
-        # [bs, all_hw, 256] -> [bs, all_hw, num_head, head_dim]
-        value = value.reshape([bs, num_value, self.num_heads, -1])
-
-        # [bs, all_hw, num_head, num_level, num_sample_point, num_offset]
-        sampling_offsets = self.sampling_offsets(query).reshape(
-            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
-        # [bs, all_hw, num_head, num_level*num_sample_point]
-        attention_weights = self.attention_weights(query).reshape(
-            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
-        attention_weights = attention_weights.softmax(-1)
-        # [bs, all_hw, num_head, num_level, num_sample_point]
-        attention_weights = attention_weights.reshape(
-            [bs, num_query, self.num_heads, self.num_levels, self.num_points])
-
-        # [bs, num_query, num_heads, num_levels, num_points, 2]
-        if reference_points.shape[-1] == 2:
-            # reference_points   [bs, all_hw, num_sample_point, 2] -> [bs, all_hw, 1, num_sample_point, 1, 2]
-            # sampling_offsets   [bs, all_hw, num_head, num_level, num_sample_point, 2]
-            # offset_normalizer  [4, 2] -> [1, 1, 1, num_sample_point, 1, 2]
-            # reference_points + sampling_offsets
-            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
-                [1, 1, 1, self.num_levels, 1, 2])
-            sampling_locations = (
-                reference_points[:, :, None, :, None, :]
-                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
-            )
-        elif reference_points.shape[-1] == 4:
-            sampling_locations = (
-                reference_points[:, :, None, :, None, :2]
-                + sampling_offsets
-                / self.num_points
-                * reference_points[:, :, None, :, None, 2:]
-                * 0.5)
-        else:
-            raise ValueError(
-                "Last dim of reference_points must be 2 or 4, but get {} instead.".
-                format(reference_points.shape[-1]))
-
-        # Multi-scale Deformable attention
-        output = self.ms_deformable_attn_core(
-            value, value_spatial_shapes, sampling_locations, attention_weights)
-        
-        # Output project
-        output = self.output_proj(output)
-
-        return output
-
-
-# ----------------- Transformer modules -----------------
-## Transformer Encoder layer
-class TransformerEncoderLayer(nn.Module):
-    def __init__(self,
-                 d_model         :int   = 256,
-                 num_heads       :int   = 8,
-                 mlp_ratio       :float = 4.0,
-                 dropout         :float = 0.1,
-                 act_type        :str   = "relu",
-                 ):
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
-        self.act_type = act_type
-        # ----------- Basic parameters -----------
-        # Multi-head Self-Attn
-        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
-        self.dropout = nn.Dropout(dropout)
-        self.norm = nn.LayerNorm(d_model)
-
-        # Feedforward Network
-        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
-
-    def with_pos_embed(self, tensor, pos):
-        return tensor if pos is None else tensor + pos
-
-
-    def forward(self, src, pos_embed):
-        """
-        Input:
-            src:       [torch.Tensor] -> [B, N, C]
-            pos_embed: [torch.Tensor] -> [B, N, C]
-        Output:
-            src:       [torch.Tensor] -> [B, N, C]
-        """
-        q = k = self.with_pos_embed(src, pos_embed)
-
-        # -------------- MHSA --------------
-        src2 = self.self_attn(q, k, value=src)[0]
-        src = src + self.dropout(src2)
-        src = self.norm(src)
-
-        # -------------- FFN --------------
-        src = self.ffn(src)
-        
-        return src
-
-## Transformer Encoder
-class TransformerEncoder(nn.Module):
-    def __init__(self,
-                 d_model        :int   = 256,
-                 num_heads      :int   = 8,
-                 num_layers     :int   = 1,
-                 mlp_ratio      :float = 4.0,
-                 pe_temperature : float = 10000.,
-                 dropout        :float = 0.1,
-                 act_type       :str   = "relu",
-                 ):
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
-        self.act_type = act_type
-        self.pe_temperature = pe_temperature
-        self.pos_embed = None
-        # ----------- Basic parameters -----------
-        self.encoder_layers = get_clones(
-            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)
-
-    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
-        assert embed_dim % 4 == 0, \
-            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
-        
-        # ----------- Check cached pos_embed -----------
-        if self.pos_embed is not None and \
-            self.pos_embed.shape[2:] == [h, w]:
-            return self.pos_embed
-        
-        # ----------- Generate grid coords -----------
-        grid_w = torch.arange(int(w), dtype=torch.float32)
-        grid_h = torch.arange(int(h), dtype=torch.float32)
-        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [H, W]
-
-        pos_dim = embed_dim // 4
-        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-        omega = 1. / (temperature**omega)
-
-        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
-        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
-
-        # shape: [1, N, C]
-        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], axis=1)[None, :, :]
-        self.pos_embed = pos_embed
-
-        return pos_embed
-
-    def forward(self, src):
-        """
-        Input:
-            src:       [torch.Tensor] -> [B, C, H, W]
-        Output:
-            src:       [torch.Tensor] -> [B, N, C]
-        """
-        # -------- Transformer encoder --------
-        for encoder in self.encoder_layers:
-            channels, fmp_h, fmp_w = src.shape[1:]
-            # [B, C, H, W] -> [B, N, C], N=HxW
-            src_flatten = src.flatten(2).permute(0, 2, 1)
-            pos_embed = self.build_2d_sincos_position_embedding(
-                    fmp_w, fmp_h, channels, self.pe_temperature)
-            memory = encoder(src_flatten, pos_embed=pos_embed)
-            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
-            src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
-
-        return src
-
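
A sanity check on `build_2d_sincos_position_embedding` above: each of the `w*h` grid positions gets an `embed_dim`-dimensional sin-cos code, so the result matches the flattened `[B, N, C]` layout the encoder layers expect (toy numbers, assuming the classes in this file are importable):

```python
enc = TransformerEncoder()  # defaults: d_model=256
pos = enc.build_2d_sincos_position_embedding(w=4, h=3, embed_dim=256)
print(pos.shape)  # torch.Size([1, 12, 256]) -> N = 4 * 3 positions
```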
-## Transformer Decoder layer
-class DeformableTransformerDecoderLayer(nn.Module):
-    def __init__(self,
-                 d_model         :int   = 256,
-                 num_heads       :int   = 8,
-                 num_levels      :int   = 3,
-                 num_points      :int   = 4,
-                 mlp_ratio       :float = 4.0,
-                 dropout         :float = 0.1,
-                 act_type        :str   = "relu",
-                 ):
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.num_levels = num_levels
-        self.num_points = num_points
-        self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
-        self.act_type = act_type
-        # ---------------- Network parameters ----------------
-        ## Multi-head Self-Attn
-        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
-        self.dropout1 = nn.Dropout(dropout)
-        self.norm1 = nn.LayerNorm(d_model)
-        ## CrossAttention
-        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
-        self.dropout2 = nn.Dropout(dropout)
-        self.norm2 = nn.LayerNorm(d_model)
-        ## FFN
-        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
-
-    def with_pos_embed(self, tensor, pos):
-        return tensor if pos is None else tensor + pos
-
-    def forward(self,
-                tgt,
-                reference_points,
-                memory,
-                memory_spatial_shapes,
-                attn_mask=None,
-                memory_mask=None,
-                query_pos_embed=None):
-        # ---------------- MSHA for Object Query -----------------
-        q = k = self.with_pos_embed(tgt, query_pos_embed)
-        if attn_mask is not None:
-            attn_mask = torch.where(
-                attn_mask.astype('bool'),
-                torch.zeros(attn_mask.shape, tgt.dtype),
-                torch.full(attn_mask.shape, float("-inf"), tgt.dtype))
-        tgt2 = self.self_attn(q, k, value=tgt)
-        tgt = tgt + self.dropout1(tgt2)
-        tgt = self.norm1(tgt)
-
-        # ---------------- CMHA for Object Query and Image-feature -----------------
-        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
-                               reference_points,
-                               memory,
-                               memory_spatial_shapes,
-                               memory_mask)
-        tgt = tgt + self.dropout2(tgt2)
-        tgt = self.norm2(tgt)
-
-        # ---------------- FeedForward Network -----------------
-        tgt = self.ffn(tgt)
-
-        return tgt
-
-## Transformer Decoder
-class DeformableTransformerDecoder(nn.Module):
-    def __init__(self,
-                 d_model        :int   = 256,
-                 num_heads      :int   = 8,
-                 num_layers     :int   = 1,
-                 num_levels     :int   = 3,
-                 num_points     :int   = 4,
-                 mlp_ratio      :float = 4.0,
-                 pe_temperature :float = 10000.,
-                 dropout        :float = 0.1,
-                 act_type       :str   = "relu",
-                 return_intermediate :bool = False,
-                 ):
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
-        self.act_type = act_type
-        self.pe_temperature = pe_temperature
-        self.pos_embed = None
-        # ----------- Network parameters -----------
-        self.decoder_layers = get_clones(
-            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)
-        self.num_layers = num_layers
-        self.return_intermediate = return_intermediate
-
-    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
-        assert embed_dim % 4 == 0, \
-            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
-        
-        # ----------- Check cached pos_embed -----------
-        if self.pos_embed is not None and \
-            self.pos_embed.shape[2:] == [h, w]:
-            return self.pos_embed
-        
-        # ----------- Generate grid coords -----------
-        grid_w = torch.arange(int(w), dtype=torch.float32)
-        grid_h = torch.arange(int(h), dtype=torch.float32)
-        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [H, W]
-
-        pos_dim = embed_dim // 4
-        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-        omega = 1. / (temperature**omega)
-
-        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
-        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
-
-        # shape: [1, N, C]
-        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], axis=1)[None, :, :]
-        self.pos_embed = pos_embed
-
-        return pos_embed
-
-    def forward(self,
-                tgt,
-                ref_points_unact,
-                memory,
-                memory_spatial_shapes,
-                bbox_head,
-                score_head,
-                query_pos_head,
-                attn_mask=None,
-                memory_mask=None):
-        output = tgt
-        dec_out_bboxes = []
-        dec_out_logits = []
-        ref_points_detach = F.sigmoid(ref_points_unact)
-        for i, layer in enumerate(self.decoder_layers):
-            ref_points_input = ref_points_detach.unsqueeze(2)
-            query_pos_embed = query_pos_head(ref_points_detach)
-
-            output = layer(output, ref_points_input, memory,
-                           memory_spatial_shapes, attn_mask,
-                           memory_mask, query_pos_embed)
-
-            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
-                ref_points_detach))
-
-            dec_out_logits.append(score_head[i](output))
-            if i == 0:
-                dec_out_bboxes.append(inter_ref_bbox)
-            else:
-                dec_out_bboxes.append(
-                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
-                        ref_points)))
-
-            ref_points = inter_ref_bbox
-            ref_points_detach = inter_ref_bbox.detach()
-
-        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
-
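
The decoder above performs iterative box refinement in logit space: each layer predicts a residual that is added to `inverse_sigmoid` of the previous reference boxes before squashing back through a sigmoid. A toy illustration (the `inverse_sigmoid` helper is the one removed from this file):

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

prev_boxes = torch.rand(2, 100, 4)     # [bs, num_queries, (cx, cy, w, h)] in (0, 1)
delta = torch.randn(2, 100, 4) * 0.1   # stands in for bbox_head[i](output)
new_boxes = torch.sigmoid(delta + inverse_sigmoid(prev_boxes))
# A zero delta leaves the boxes unchanged, so every layer learns a residual update.
```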

+ 29 - 16
models/detectors/rtdetr/basic_modules/dn_compoments.py

@@ -6,13 +6,16 @@ def inverse_sigmoid(x, eps=1e-5):
     return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
 
 def bbox_cxcywh_to_xyxy(x):
-    cxcy, wh = torch.split(x, 2, axis=-1)
-    return torch.cat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], dim=-1)
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
 
 def bbox_xyxy_to_cxcywh(x):
-    x1, y1, x2, y2 = x.split(4, axis=-1)
-    return torch.cat(
-        [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], axis=-1)
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
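
A quick round trip through the two converters above (toy box; the functions are assumed importable from this file):

```python
import torch

box_cxcywh = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # (cx, cy, w, h)
box_xyxy = bbox_cxcywh_to_xyxy(box_cxcywh)          # tensor([[0.4, 0.3, 0.6, 0.7]])
assert torch.allclose(bbox_xyxy_to_cxcywh(box_xyxy), box_cxcywh)
```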
 
 def get_contrastive_denoising_training_group(targets,
                                              num_classes,
@@ -23,7 +26,7 @@ def get_contrastive_denoising_training_group(targets,
                                              box_noise_scale=1.0):
     if num_denoising <= 0:
         return None, None, None, None
-    num_gts = [len(t) for t in targets["labels"]]
+    num_gts = [len(t["labels"]) for t in targets]
     max_gt_num = max(num_gts)
     if max_gt_num == 0:
         return None, None, None, None
@@ -32,20 +35,22 @@ def get_contrastive_denoising_training_group(targets,
     num_group = 1 if num_group == 0 else num_group
 
     # pad gt to max_num of a batch
-    bs = len(targets["labels"])
+    bs = len(targets)
+    # [bs, max_gt_num]
     input_query_class = torch.full([bs, max_gt_num], num_classes).long()
+    # [bs, max_gt_num, 4]
     input_query_bbox = torch.zeros([bs, max_gt_num, 4])
     pad_gt_mask = torch.zeros([bs, max_gt_num])
     for i in range(bs):
         num_gt = num_gts[i]
         if num_gt > 0:
-            input_query_class[i, :num_gt] = targets["labels"][i].squeeze(-1)
-            input_query_bbox[i, :num_gt] = targets["boxes"][i]
+            input_query_class[i, :num_gt] = targets[i]["labels"].squeeze(-1)
+            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
             pad_gt_mask[i, :num_gt] = 1
 
     # each group has positive and negative queries.
-    input_query_class = input_query_class.repeat(1, 2 * num_group)
-    input_query_bbox = input_query_bbox.repeat(1, 2 * num_group, 1)
+    input_query_class = input_query_class.repeat(1, 2 * num_group)  # [bs, num_denoising], where num_denoising = 2 * num_group * max_gt_num
+    input_query_bbox = input_query_bbox.repeat(1, 2 * num_group, 1) # [bs, num_denoising, 4]
     pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)
 
     # positive and negative mask
@@ -60,10 +65,10 @@ def get_contrastive_denoising_training_group(targets,
     dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
     
     # total denoising queries
-    num_denoising = int(max_gt_num * 2 * num_group)
+    num_denoising = int(max_gt_num * 2 * num_group)  # doubled: one positive and one negative group per GT
 
     if label_noise_ratio > 0:
-        input_query_class = input_query_class.flatten()
+        input_query_class = input_query_class.flatten()  # [bs * num_denoising]
         pad_gt_mask = pad_gt_mask.flatten()
         # half of bbox prob
         mask = torch.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
@@ -71,7 +76,10 @@ def get_contrastive_denoising_training_group(targets,
         # randomly put a new one here
         new_label = torch.randint_like(
             chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
-        input_query_class.scatter_(chosen_idx, new_label)
+        # [bs * num_denoising]
+        input_query_class = torch.scatter(input_query_class, 0, chosen_idx, new_label)
+        # input_query_class.scatter_(chosen_idx, new_label)
+        # [bs * num_denoising] -> [bs, num_denoising]
         input_query_class = input_query_class.reshape(bs, num_denoising)
         pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising)
 
@@ -86,12 +94,17 @@ def get_contrastive_denoising_training_group(targets,
             1 - negative_gt_mask)
         rand_part *= rand_sign
         known_bbox += rand_part * diff
-        known_bbox.clip_(min=0.0, max=1.0)
+        known_bbox.clamp_(min=0.0, max=1.0)
         input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
         input_query_bbox = inverse_sigmoid(input_query_bbox)
 
+    # [num_classes + 1, hidden_dim]
     class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]])])
-    input_query_class = torch.gather(class_embed, 1, input_query_class.flatten())
+    # input_query_class = paddle.gather(class_embed, input_query_class.flatten(), axis=0)
+
+    # input_query_class: [bs, num_denoising] -> [bs*num_denoising, hidden_dim]
+    input_query_class = torch.index_select(class_embed, 0, input_query_class.flatten())
+    # [bs*num_denoising, hidden_dim] -> [bs, num_denoising, hidden_dim]
     input_query_class = input_query_class.reshape(bs, num_denoising, -1)
     
     tgt_size = num_denoising + num_queries
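
For intuition, the group arithmetic above with illustrative numbers (assuming `num_group = num_denoising // max_gt_num`, as computed near the top of this function):

```python
num_denoising, max_gt_num = 100, 3
num_group = max(num_denoising // max_gt_num, 1)  # 33
# each group holds one positive and one negative copy of every padded GT
total_dn_queries = max_gt_num * 2 * num_group    # 198
tgt_size = total_dn_queries + 300                # e.g. with 300 regular queries
```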

+ 85 - 0
models/detectors/rtdetr/basic_modules/ext_op/README.md

@@ -0,0 +1,85 @@
+# Compiling the custom multi-scale deformable attention OP
+This custom OP follows the Paddle guide on [custom external operators](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html).
+
+## 1. Requirements
+- Paddle >= 2.3.2
+- gcc 8.2
+
+## 2. Installation
+Build and install from this directory:
+```
+cd rtdetr_paddle/ppdet/modeling/transformers/ext_op/
+python setup_ms_deformable_attn_op.py install
+```
+
+Once compiled, the op is ready to use. Below is an example of calling `ms_deformable_attn`:
+```
+import paddle
+
+# import the custom op
+from deformable_detr_ops import ms_deformable_attn
+
+# build fake input tensors
+bs, n_heads, c = 2, 8, 8
+query_length, n_levels, n_points = 2, 2, 2
+spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
+level_start_index = paddle.concat((paddle.to_tensor(
+    [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
+value_length = sum([(H * W).item() for H, W in spatial_shapes])
+
+def get_test_tensors(channels):
+    value = paddle.rand(
+        [bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
+    sampling_locations = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points, 2],
+        dtype=paddle.float32)
+    attention_weights = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points],
+        dtype=paddle.float32) + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
+        -2, keepdim=True)
+    return [value, sampling_locations, attention_weights]
+
+value, sampling_locations, attention_weights = get_test_tensors(c)
+
+output = ms_deformable_attn(value,
+                            spatial_shapes,
+                            level_start_index,
+                            sampling_locations,
+                            attention_weights)
+```
+
+## 3. Unit test
+Run the unit test to verify that the custom operator works correctly:
+```
+python test_ms_deformable_attn_op.py
+```
+On success, it prints:
+```
+*True check_forward_equal_with_paddle_float: max_abs_err 6.98e-10 max_rel_err 2.03e-07
+*tensor1 True check_gradient_numerical(D=30)
+*tensor2 True check_gradient_numerical(D=30)
+*tensor3 True check_gradient_numerical(D=30)
+*tensor1 True check_gradient_numerical(D=32)
+*tensor2 True check_gradient_numerical(D=32)
+*tensor3 True check_gradient_numerical(D=32)
+*tensor1 True check_gradient_numerical(D=64)
+*tensor2 True check_gradient_numerical(D=64)
+*tensor3 True check_gradient_numerical(D=64)
+*tensor1 True check_gradient_numerical(D=71)
+*tensor2 True check_gradient_numerical(D=71)
+*tensor3 True check_gradient_numerical(D=71)
+*tensor1 True check_gradient_numerical(D=128)
+*tensor2 True check_gradient_numerical(D=128)
+*tensor3 True check_gradient_numerical(D=128)
+*tensor1 True check_gradient_numerical(D=1024)
+*tensor2 True check_gradient_numerical(D=1024)
+*tensor3 True check_gradient_numerical(D=1024)
+*tensor1 True check_gradient_numerical(D=1025)
+*tensor2 True check_gradient_numerical(D=1025)
+*tensor3 True check_gradient_numerical(D=1025)
+*tensor1 True check_gradient_numerical(D=2048)
+*tensor2 True check_gradient_numerical(D=2048)
+*tensor3 True check_gradient_numerical(D=2048)
+*tensor1 True check_gradient_numerical(D=3096)
+*tensor2 True check_gradient_numerical(D=3096)
+*tensor3 True check_gradient_numerical(D=3096)
+```

+ 65 - 0
models/detectors/rtdetr/basic_modules/ext_op/ms_deformable_attn_op.cc

@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/extension.h"
+
+#include <vector>
+
+// declare GPU implementation
+std::vector<paddle::Tensor>
+MSDeformableAttnCUDAForward(const paddle::Tensor &value,
+                            const paddle::Tensor &value_spatial_shapes,
+                            const paddle::Tensor &value_level_start_index,
+                            const paddle::Tensor &sampling_locations,
+                            const paddle::Tensor &attention_weights);
+
+std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
+    const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
+    const paddle::Tensor &value_level_start_index,
+    const paddle::Tensor &sampling_locations,
+    const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out);
+
+// CPU not implemented
+
+std::vector<std::vector<int64_t>>
+MSDeformableAttnInferShape(std::vector<int64_t> value_shape,
+                           std::vector<int64_t> value_spatial_shapes_shape,
+                           std::vector<int64_t> value_level_start_index_shape,
+                           std::vector<int64_t> sampling_locations_shape,
+                           std::vector<int64_t> attention_weights_shape) {
+  return {{value_shape[0], sampling_locations_shape[1],
+           value_shape[2] * value_shape[3]}};
+}
+
+std::vector<paddle::DataType>
+MSDeformableAttnInferDtype(paddle::DataType value_dtype,
+                           paddle::DataType value_spatial_shapes_dtype,
+                           paddle::DataType value_level_start_index_dtype,
+                           paddle::DataType sampling_locations_dtype,
+                           paddle::DataType attention_weights_dtype) {
+  return {value_dtype};
+}
+
+PD_BUILD_OP(ms_deformable_attn)
+    .Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
+             "AttentionWeights"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDAForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(MSDeformableAttnInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MSDeformableAttnInferDtype));
+
+PD_BUILD_GRAD_OP(ms_deformable_attn)
+    .Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
+             "AttentionWeights", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("Value"), paddle::Grad("SpatialShapes"),
+              paddle::Grad("LevelIndex"), paddle::Grad("SamplingLocations"),
+              paddle::Grad("AttentionWeights")})
+    .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDABackward));

+ 1073 - 0
models/detectors/rtdetr/basic_modules/ext_op/ms_deformable_attn_op.cu

@@ -0,0 +1,1073 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/extension.h"
+
+#define CUDA_KERNEL_LOOP(i, n)                                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);                 \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads) {
+  return (N + num_threads - 1) / num_threads;
+}
+
+// forward bilinear
+template <typename data_t>
+__device__ data_t deformable_attn_bilinear_forward(
+    const data_t *&bottom_data, const int &height, const int &width,
+    const int &nheads, const int &channels, const data_t &h, const data_t &w,
+    const int &m, const int &c) {
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const data_t lh = h - h_low;
+  const data_t lw = w - w_low;
+  const data_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  data_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0) {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+  }
+  data_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1) {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+  }
+  data_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0) {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+  }
+  data_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1) {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+  }
+
+  const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+// forward kernel
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_forward(
+    const int n, const data_t *data_value, const int64_t *data_spatial_shapes,
+    const int64_t *data_level_start_index, const data_t *data_sampling_loc,
+    const data_t *data_attn_weight, const int batch_size,
+    const int value_length, const int num_heads, const int channels,
+    const int num_levels, const int query_length, const int num_points,
+    data_t *output_data_ptr) {
+  CUDA_KERNEL_LOOP(index, n) {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    data_t *data_ptr = output_data_ptr + index;
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+    data_t col = 0;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const data_t *data_value_ptr = data_value + (data_value_ptr_init_offset +
+                                                   level_start_id * qid_stride);
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          col += deformable_attn_bilinear_forward(
+                     data_value_ptr, spatial_h, spatial_w, num_heads, channels,
+                     h_im, w_im, m_col, c_col) *
+                 weight;
+        }
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+      }
+    }
+    *data_ptr = col;
+  }
+}
+
+#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
+// forward
+std::vector<paddle::Tensor>
+MSDeformableAttnCUDAForward(const paddle::Tensor &value,
+                            const paddle::Tensor &value_spatial_shapes,
+                            const paddle::Tensor &value_level_start_index,
+                            const paddle::Tensor &sampling_locations,
+                            const paddle::Tensor &attention_weights) {
+
+  CHECK_INPUT_GPU(value);
+  CHECK_INPUT_GPU(value_spatial_shapes);
+  CHECK_INPUT_GPU(value_level_start_index);
+  CHECK_INPUT_GPU(sampling_locations);
+  CHECK_INPUT_GPU(attention_weights);
+
+  const int batch_size = value.shape()[0];
+  const int value_length = value.shape()[1];
+  const int num_heads = value.shape()[2];
+  const int channels = value.shape()[3];
+
+  const int num_levels = value_spatial_shapes.shape()[0];
+  const int query_length = sampling_locations.shape()[1];
+  const int num_points = sampling_locations.shape()[4];
+
+  auto output = paddle::full({batch_size, query_length, num_heads * channels},
+                             0, value.dtype(), paddle::GPUPlace());
+
+  const int num_kernels = batch_size * query_length * num_heads * channels;
+  deformable_attn_cuda_kernel_forward<float>
+      <<<GET_BLOCKS(num_kernels, CUDA_NUM_THREADS), CUDA_NUM_THREADS, 0,
+         value.stream()>>>(num_kernels, value.data<float>(),
+                           value_spatial_shapes.data<int64_t>(),
+                           value_level_start_index.data<int64_t>(),
+                           sampling_locations.data<float>(),
+                           attention_weights.data<float>(), batch_size,
+                           value_length, num_heads, channels, num_levels,
+                           query_length, num_points, output.data<float>());
+  return {output};
+}
+
+// backward bilinear
+template <typename data_t>
+__device__ void deformable_attn_bilinear_backward(
+    const data_t *&bottom_data, const int &height, const int &width,
+    const int &nheads, const int &channels, const data_t &h, const data_t &w,
+    const int &m, const int &c, const data_t &top_grad,
+    const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const data_t lh = h - h_low;
+  const data_t lw = w - w_low;
+  const data_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const data_t top_grad_value = top_grad * attn_weight;
+  data_t grad_h_weight = 0, grad_w_weight = 0;
+
+  data_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0) {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value + ptr1, w1 * top_grad_value);
+  }
+  data_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1) {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value + ptr2, w2 * top_grad_value);
+  }
+  data_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0) {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value + ptr3, w3 * top_grad_value);
+  }
+  data_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1) {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value + ptr4, w4 * top_grad_value);
+  }
+
+  const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  *grad_attn_weight = top_grad * val;
+  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+template <typename data_t>
+__device__ void deformable_attn_bilinear_backward_gm(
+    const data_t *&bottom_data, const int &height, const int &width,
+    const int &nheads, const int &channels, const data_t &h, const data_t &w,
+    const int &m, const int &c, const data_t &top_grad,
+    const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const data_t lh = h - h_low;
+  const data_t lw = w - w_low;
+  const data_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const data_t top_grad_value = top_grad * attn_weight;
+  data_t grad_h_weight = 0, grad_w_weight = 0;
+
+  data_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0) {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value + ptr1, w1 * top_grad_value);
+  }
+  data_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1) {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value + ptr2, w2 * top_grad_value);
+  }
+  data_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0) {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value + ptr3, w3 * top_grad_value);
+  }
+  data_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1) {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value + ptr4, w4 * top_grad_value);
+  }
+
+  const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  atomicAdd(grad_attn_weight, top_grad * val);
+  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+// backward kernels
+// channels > 1024
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    extern __shared__ int _s[];
+    data_t *cache_grad_sampling_loc = (data_t *)_s;
+    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
+             s >>= 1, spre >>= 1) {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] +=
+                cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre) {
+              cache_grad_attn_weight[tid] +=
+                  cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] +=
+                  cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] +=
+                  cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0) {
+          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_gm(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward_gm(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              grad_sampling_loc, grad_attn_weight);
+        }
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+// channels <= 1024
+template <typename data_t, unsigned int blockSize>
+__global__ void
+deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    __shared__ data_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ data_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0) {
+          data_t _grad_w = cache_grad_sampling_loc[0],
+                 _grad_h = cache_grad_sampling_loc[1],
+                 _grad_a = cache_grad_attn_weight[0];
+          int sid = 2;
+          for (unsigned int tid = 1; tid < blockSize; ++tid) {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
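+// Same addressing as the v1 kernel above, but the cached per-thread gradients
+// are combined with a parallel tree reduction in shared memory; the templated
+// blockSize must be a power of two for the halving loop to cover every slot.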
+template <typename data_t, unsigned int blockSize>
+__global__ void
+deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    __shared__ data_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ data_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] +=
+                cache_grad_sampling_loc[xid2 + 1];
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0) {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
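+// Variant for channel counts with no matching template: the gradient caches
+// live in dynamically sized shared memory (extern __shared__) and thread 0
+// accumulates them serially, as in the blocksize-aware v1 kernel above.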
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v1(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    extern __shared__ int _s[];
+    data_t *cache_grad_sampling_loc = (data_t *)_s;
+    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0) {
+          data_t _grad_w = cache_grad_sampling_loc[0],
+                 _grad_h = cache_grad_sampling_loc[1],
+                 _grad_a = cache_grad_attn_weight[0];
+          int sid = 2;
+          for (unsigned int i = 1; i < blockDim.x; ++i) {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[i];
+            sid += 2;
+          }
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
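+// Dynamic-shared-memory variant with a tree reduction; the extra
+// "tid + (s << 1) < spre" term folds the leftover element into the sum when
+// the active width is odd, so non-power-of-two block sizes reduce correctly.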
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    extern __shared__ int _s[];
+    data_t *cache_grad_sampling_loc = (data_t *)_s;
+    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
+             s >>= 1, spre >>= 1) {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] +=
+                cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre) {
+              cache_grad_attn_weight[tid] +=
+                  cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] +=
+                  cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] +=
+                  cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0) {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+// backward branch
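+// Dispatch: power-of-two channel counts up to 1024 use the blockSize-templated
+// kernels (serial reduction for <= 32 channels, tree reduction from 64 up);
+// other counts <= 1024 use the dynamic-shared-memory kernels; counts above
+// 1024 use the multi-block variant when divisible by 1024, else the
+// global-memory fallback.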
+template <typename data_t>
+void deformable_attn_cuda_backward(
+    cudaStream_t stream, const data_t *grad_out, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  const int num_threads =
+      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+  const int num_kernels = batch_size * query_length * num_heads * channels;
+  const int num_actual_kernels =
+      batch_size * query_length * num_heads * channels;
+  if (channels > 1024) {
+    if ((channels & 1023) == 0) {
+      deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks<data_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             num_threads * 3 * sizeof(data_t), stream>>>(
+              num_kernels, grad_out, data_value, data_spatial_shapes,
+              data_level_start_index, data_sampling_loc, data_attn_weight,
+              batch_size, value_length, num_heads, channels, num_levels,
+              query_length, num_points, grad_value, grad_sampling_loc,
+              grad_attn_weight);
+    } else {
+      deformable_attn_cuda_kernel_backward_gm<data_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+    }
+  } else {
+    switch (channels) {
+    case 1:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         1>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 2:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         2>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 4:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         4>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 8:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         8>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 16:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         16>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 32:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         32>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 64:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         64>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 128:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         128>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 256:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         256>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 512:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         512>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 1024:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         1024>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    default:
+      if (channels < 64) {
+        deformable_attn_cuda_kernel_backward_shm_reduce_v1<data_t>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+               num_threads * 3 * sizeof(data_t), stream>>>(
+                num_kernels, grad_out, data_value, data_spatial_shapes,
+                data_level_start_index, data_sampling_loc, data_attn_weight,
+                batch_size, value_length, num_heads, channels, num_levels,
+                query_length, num_points, grad_value, grad_sampling_loc,
+                grad_attn_weight);
+      } else {
+        deformable_attn_cuda_kernel_backward_shm_reduce_v2<data_t>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+               num_threads * 3 * sizeof(data_t), stream>>>(
+                num_kernels, grad_out, data_value, data_spatial_shapes,
+                data_level_start_index, data_sampling_loc, data_attn_weight,
+                batch_size, value_length, num_heads, channels, num_levels,
+                query_length, num_points, grad_value, grad_sampling_loc,
+                grad_attn_weight);
+      }
+    }
+  }
+}
+
+// backward
+std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
+    const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
+    const paddle::Tensor &value_level_start_index,
+    const paddle::Tensor &sampling_locations,
+    const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out) {
+
+  CHECK_INPUT_GPU(value);
+  CHECK_INPUT_GPU(value_spatial_shapes);
+  CHECK_INPUT_GPU(value_level_start_index);
+  CHECK_INPUT_GPU(sampling_locations);
+  CHECK_INPUT_GPU(attention_weights);
+  CHECK_INPUT_GPU(grad_out);
+
+  const int batch_size = value.shape()[0];
+  const int value_length = value.shape()[1];
+  const int num_heads = value.shape()[2];
+  const int channels = value.shape()[3];
+
+  const int num_levels = value_spatial_shapes.shape()[0];
+  const int query_length = sampling_locations.shape()[1];
+  const int num_points = sampling_locations.shape()[4];
+
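+  // Zero-initialized gradient buffers. Spatial shapes and level start indices
+  // are integer inputs, so their "gradients" stay zero and are returned only
+  // to satisfy the op's output signature.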
+  auto grad_value =
+      paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
+  auto grad_spatial_shapes =
+      paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
+  auto grad_level_start_index =
+      paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
+  auto grad_sampling_locations =
+      paddle::full(sampling_locations.shape(), 0, sampling_locations.dtype(),
+                   paddle::GPUPlace());
+  auto grad_attention_weights =
+      paddle::full(attention_weights.shape(), 0, attention_weights.dtype(),
+                   paddle::GPUPlace());
+
+  deformable_attn_cuda_backward<float>(
+      value.stream(), grad_out.data<float>(), value.data<float>(),
+      value_spatial_shapes.data<int64_t>(),
+      value_level_start_index.data<int64_t>(), sampling_locations.data<float>(),
+      attention_weights.data<float>(), batch_size, value_length, num_heads,
+      channels, num_levels, query_length, num_points, grad_value.data<float>(),
+      grad_sampling_locations.data<float>(),
+      grad_attention_weights.data<float>());
+
+  return {grad_value, grad_spatial_shapes, grad_level_start_index,
+          grad_sampling_locations, grad_attention_weights};
+}

+ 7 - 0
models/detectors/rtdetr/basic_modules/ext_op/setup_ms_deformable_attn_op.py

@@ -0,0 +1,7 @@
+from paddle.utils.cpp_extension import CUDAExtension, setup
+
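+# Compiles the MS-Deformable-Attention CUDA op into an importable module named
+# "deformable_detr_ops". Run `python setup_ms_deformable_attn_op.py install`
+# from this directory with a CUDA-enabled PaddlePaddle installed.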
+if __name__ == "__main__":
+    setup(
+        name='deformable_detr_ops',
+        ext_modules=CUDAExtension(
+            sources=['ms_deformable_attn_op.cc', 'ms_deformable_attn_op.cu']))

+ 140 - 0
models/detectors/rtdetr/basic_modules/ext_op/test_ms_deformable_attn_op.py

@@ -0,0 +1,140 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os
+import sys
+import random
+import numpy as np
+import paddle
+# add python path of PaddleDetection to sys.path
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 5)))
+if parent_path not in sys.path:
+    sys.path.append(parent_path)
+
+from ppdet.modeling.transformers.utils import deformable_attention_core_func
+ms_deform_attn_core_paddle = deformable_attention_core_func
+
+try:
+    gpu_index = int(sys.argv[1])
+except:
+    gpu_index = 0
+print(f'Use gpu {gpu_index} to test...')
+paddle.set_device(f'gpu:{gpu_index}')
+
+try:
+    from deformable_detr_ops import ms_deformable_attn
+except Exception as e:
+    print('import deformable_detr_ops error', e)
+    sys.exit(-1)
+
+paddle.seed(1)
+random.seed(1)
+np.random.seed(1)
+
+bs, n_heads, c = 2, 8, 8
+query_length, n_levels, n_points = 2, 2, 2
+spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
+level_start_index = paddle.concat((paddle.to_tensor(
+    [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
+value_length = sum([(H * W).item() for H, W in spatial_shapes])
+
+
+def get_test_tensors(channels):
+    value = paddle.rand(
+        [bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
+    sampling_locations = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points, 2],
+        dtype=paddle.float32)
+    attention_weights = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points],
+        dtype=paddle.float32) + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
+        -2, keepdim=True)
+
+    return [value, sampling_locations, attention_weights]
+
+
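+# Forward check: on identical random inputs, the compiled CUDA op and the
+# pure-Paddle reference should agree within float32 tolerance.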
+@paddle.no_grad()
+def check_forward_equal_with_paddle_float():
+    value, sampling_locations, attention_weights = get_test_tensors(c)
+
+    output_paddle = ms_deform_attn_core_paddle(
+        value, spatial_shapes, level_start_index, sampling_locations,
+        attention_weights).detach().cpu()
+    output_cuda = ms_deformable_attn(value, spatial_shapes, level_start_index,
+                                     sampling_locations,
+                                     attention_weights).detach().cpu()
+    fwdok = paddle.allclose(
+        output_cuda, output_paddle, rtol=1e-2, atol=1e-3).item()
+    max_abs_err = (output_cuda - output_paddle).abs().max().item()
+    max_rel_err = (
+        (output_cuda - output_paddle).abs() / output_paddle.abs()).max().item()
+
+    print(
+        f'*{fwdok} check_forward_equal_with_paddle_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}'
+    )
+
+
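+# Gradient check: backprop a scalar sum through both implementations and
+# compare the gradients w.r.t. value, sampling locations and attention weights.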
+def check_gradient_numerical(channels=4):
+    value_paddle, sampling_locations_paddle, attention_weights_paddle = get_test_tensors(
+        channels)
+    value_paddle.stop_gradient = False
+    sampling_locations_paddle.stop_gradient = False
+    attention_weights_paddle.stop_gradient = False
+
+    value_cuda = value_paddle.detach().clone()
+    sampling_locations_cuda = sampling_locations_paddle.detach().clone()
+    attention_weights_cuda = attention_weights_paddle.detach().clone()
+    value_cuda.stop_gradient = False
+    sampling_locations_cuda.stop_gradient = False
+    attention_weights_cuda.stop_gradient = False
+
+    output_paddle = ms_deform_attn_core_paddle(
+        value_paddle, spatial_shapes, level_start_index,
+        sampling_locations_paddle, attention_weights_paddle)
+    output_paddle.sum().backward()
+
+    output_cuda = ms_deformable_attn(value_cuda, spatial_shapes,
+                                     level_start_index, sampling_locations_cuda,
+                                     attention_weights_cuda)
+    output_cuda.sum().backward()
+
+    res = paddle.allclose(
+        value_paddle.grad, value_cuda.grad, rtol=1e-2, atol=1e-3).item()
+    print(f'*tensor1 {res} check_gradient_numerical(D={channels})')
+
+    res = paddle.allclose(
+        sampling_locations_paddle.grad,
+        sampling_locations_cuda.grad,
+        rtol=1e-2,
+        atol=1e-3).item()
+    print(f'*tensor2 {res} check_gradient_numerical(D={channels})')
+
+    res = paddle.allclose(
+        attention_weights_paddle.grad,
+        attention_weights_cuda.grad,
+        rtol=1e-2,
+        atol=1e-3).item()
+    print(f'*tensor3 {res} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+    check_forward_equal_with_paddle_float()
+
+    for channels in [30, 32, 64, 71, 128, 1024, 1025, 2048, 3096]:
+        check_gradient_numerical(channels)

+ 6 - 3
models/detectors/rtdetr/basic_modules/fpn.py

@@ -4,9 +4,11 @@ import torch.nn.functional as F
 from typing import List
 
 try:
-    from .basic import get_clones, BasicConv, RTCBlock, TransformerEncoder
+    from .basic import BasicConv, RTCBlock
+    from .transformer import TransformerEncoder
 except:
-    from  basic import get_clones, BasicConv, RTCBlock, TransformerEncoder
+    from  basic import BasicConv, RTCBlock
+    from  transformer import TransformerEncoder
 
 
 # Build PaFPN
@@ -34,7 +36,7 @@ def build_fpn(cfg, in_dims, out_dim):
 ## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
 class HybridEncoder(nn.Module):
     def __init__(self, 
-                 in_dims     :List  = [256, 512, 512],
+                 in_dims     :List  = [256, 512, 1024],
                  out_dim     :int   = 256,
                  width       :float = 1.0,
                  depth       :float = 1.0,
@@ -55,6 +57,7 @@ class HybridEncoder(nn.Module):
         # ---------------- Basic parameters ----------------
         self.in_dims = in_dims
         self.out_dim = round(out_dim * width)
+        self.out_dims = [self.out_dim] * len(in_dims)
         self.width = width
         self.depth = depth
         self.num_heads = num_heads

+ 492 - 0
models/detectors/rtdetr/basic_modules/transformer.py

@@ -0,0 +1,492 @@
+import math
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import constant_, xavier_uniform_
+
+try:
+    from .basic import get_activation
+except:
+    from  basic import get_activation
+
+
+def get_clones(module, N):
+    if N <= 0:
+        return None
+    else:
+        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0., max=1.)
+    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
+
+
+# ----------------- MLP modules -----------------
+class MLP(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+class FFN(nn.Module):
+    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
+        super().__init__()
+        self.ffn_dim = round(d_model * mlp_ratio)
+        self.linear1 = nn.Linear(d_model, self.ffn_dim)
+        self.activation = get_activation(act_type)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(self.ffn_dim, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm(src)
+        
+        return src
+    
+
+# ----------------- Basic Transformer Ops -----------------
+def multi_scale_deformable_attn_pytorch(
+    value: torch.Tensor,
+    value_spatial_shapes: torch.Tensor,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+) -> torch.Tensor:
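+    """
+    Pure-PyTorch fallback for multi-scale deformable attention.
+
+    Shapes:
+        value:                [bs, value_length, num_heads, embed_dims]
+        value_spatial_shapes: [num_levels, 2], (H, W) of each level
+        sampling_locations:   [bs, num_queries, num_heads, num_levels, num_points, 2], in [0, 1]
+        attention_weights:    [bs, num_queries, num_heads, num_levels, num_points]
+    Returns:
+        [bs, num_queries, num_heads * embed_dims]
+    """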
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+    
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = (
+            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
+        )
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampling_value_l_ = F.grid_sample(
+            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+        )
+        sampling_value_list.append(sampling_value_l_)
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs, num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points
+    )
+    output = (
+        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+        .sum(-1)
+        .view(bs, num_heads * embed_dims, num_queries)
+    )
+    return output.transpose(1, 2).contiguous()
+
+class MSDeformableAttention(nn.Module):
+    def __init__(self,
+                 embed_dim=256,
+                 num_heads=8,
+                 num_levels=4,
+                 num_points=4):
+        """
+        Multi-Scale Deformable Attention Module
+        """
+        super(MSDeformableAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.total_points = num_heads * num_levels * num_points
+
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
+        self.attention_weights = nn.Linear(embed_dim, self.total_points)
+        self.value_proj = nn.Linear(embed_dim, embed_dim)
+        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        
+        try:
+            # use cuda op
+            from deformable_detr_ops import ms_deformable_attn
+            self.ms_deformable_attn_core = ms_deformable_attn
+        except:
+            # use torch func
+            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        """
+        Default initialization for Parameters of Module.
+        """
+        constant_(self.sampling_offsets.weight.data, 0.0)
+        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
+            2.0 * math.pi / self.num_heads
+        )
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (
+            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+            .view(self.num_heads, 1, 1, 2)
+            .repeat(1, self.num_levels, self.num_points, 1)
+        )
+        for i in range(self.num_points):
+            grid_init[:, :, i, :] *= i + 1
+        with torch.no_grad():
+            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+        constant_(self.attention_weights.weight.data, 0.0)
+        constant_(self.attention_weights.bias.data, 0.0)
+        xavier_uniform_(self.value_proj.weight.data)
+        constant_(self.value_proj.bias.data, 0.0)
+        xavier_uniform_(self.output_proj.weight.data)
+        constant_(self.output_proj.bias.data, 0.0)
+
+    def forward(self,
+                query,
+                reference_points,
+                value,
+                value_spatial_shapes,
+                value_mask=None):
+        """
+        Args:
+            query (Tensor): [bs, query_length, C]
+            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
+                bottom-right (1, 1), including padding area
+            value (Tensor): [bs, value_length, C]
+            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
+
+        Returns:
+            output (Tensor): [bs, Length_{query}, C]
+        """
+        bs, num_query = query.shape[:2]
+        num_value = value.shape[1]
+        assert sum([s[0] * s[1] for s in value_spatial_shapes]) == num_value
+
+        # Value projection
+        value = self.value_proj(value)
+        # fill "0" for the padding part
+        if value_mask is not None:
+            value_mask = value_mask.to(value.dtype).unsqueeze(-1)
+            value *= value_mask
+        # [bs, all_hw, 256] -> [bs, all_hw, num_head, head_dim]
+        value = value.reshape([bs, num_value, self.num_heads, -1])
+
+        # [bs, all_hw, num_head, num_level, num_sample_point, num_offset]
+        sampling_offsets = self.sampling_offsets(query).reshape(
+            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
+        # [bs, all_hw, num_head, num_level*num_sample_point]
+        attention_weights = self.attention_weights(query).reshape(
+            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
+        attention_weights = attention_weights.softmax(-1)
+        # [bs, all_hw, num_head, num_level, num_sample_point]
+        attention_weights = attention_weights.reshape(
+            [bs, num_query, self.num_heads, self.num_levels, self.num_points])
+
+        # [bs, num_query, num_heads, num_levels, num_points, 2]
+        if reference_points.shape[-1] == 2:
+            # reference_points   [bs, all_hw, num_level, 2] -> [bs, all_hw, 1, num_level, 1, 2]
+            # sampling_offsets   [bs, all_hw, num_head, num_level, num_sample_point, 2]
+            # offset_normalizer  [num_level, 2] -> [1, 1, 1, num_level, 1, 2], holding (W, H)
+            # sampling_locations = reference_points + normalized sampling_offsets
+            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
+                [1, 1, 1, self.num_levels, 1, 2])
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :]
+                + sampling_offsets / offset_normalizer
+            )
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :2]
+                + sampling_offsets
+                / self.num_points
+                * reference_points[:, :, None, :, None, 2:]
+                * 0.5)
+        else:
+            raise ValueError(
+                "Last dim of reference_points must be 2 or 4, but get {} instead.".
+                format(reference_points.shape[-1]))
+
+        # Multi-scale Deformable attention
+        output = self.ms_deformable_attn_core(
+            value, value_spatial_shapes, sampling_locations, attention_weights)
+        
+        # Output project
+        output = self.output_proj(output)
+
+        return output
+
+
+# ----------------- Transformer modules -----------------
+## Transformer Encoder layer
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model         :int   = 256,
+                 num_heads       :int   = 8,
+                 mlp_ratio       :float = 4.0,
+                 dropout         :float = 0.1,
+                 act_type        :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        # ----------- Network parameters -----------
+        # Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+        # Feedforward Network
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+
+    def forward(self, src, pos_embed):
+        """
+        Input:
+            src:       [torch.Tensor] -> [B, N, C]
+            pos_embed: [torch.Tensor] -> [B, N, C]
+        Output:
+            src:       [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos_embed)
+
+        # -------------- MHSA --------------
+        src2 = self.self_attn(q, k, value=src)[0]
+        src = src + self.dropout(src2)
+        src = self.norm(src)
+
+        # -------------- FFN --------------
+        src = self.ffn(src)
+        
+        return src
+
+## Transformer Encoder
+class TransformerEncoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 mlp_ratio      :float = 4.0,
+                 pe_temperature : float = 10000.,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pe_temperature = pe_temperature
+        self.pos_embed = None
+        self.pos_embed_hw = None
+        # ----------- Network parameters -----------
+        self.encoder_layers = get_clones(
+            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)
+
+    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
+        assert embed_dim % 4 == 0, \
+            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        
+        # ----------- Check cached pos_embed -----------
+        if self.pos_embed is not None and \
+            self.pos_embed_hw == (h, w):
+            return self.pos_embed
+        
+        # ----------- Generate grid coords -----------
+        grid_w = torch.arange(int(w), dtype=torch.float32)
+        grid_h = torch.arange(int(h), dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
+
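+        # Half of the channels encode the x axis (grid_w) and half the y axis
+        # (grid_h); each half is split again into sin/cos pairs, which is why
+        # embed_dim must be divisible by 4.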
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
+        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
+
+        # shape: [1, N, C]
+        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w),
+                               torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+        self.pos_embed = pos_embed
+        self.pos_embed_hw = (h, w)
+
+        return pos_embed
+
+    def forward(self, src):
+        """
+        Input:
+            src:       [torch.Tensor] -> [B, C, H, W]
+        Output:
+            src:       [torch.Tensor] -> [B, N, C]
+        """
+        # -------- Transformer encoder --------
+        for encoder in self.encoder_layers:
+            channels, fmp_h, fmp_w = src.shape[1:]
+            # [B, C, H, W] -> [B, N, C], N=HxW
+            src_flatten = src.flatten(2).permute(0, 2, 1)
+            pos_embed = self.build_2d_sincos_position_embedding(
+                    fmp_w, fmp_h, channels, self.pe_temperature).to(src_flatten.device)
+            memory = encoder(src_flatten, pos_embed=pos_embed)
+            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
+            src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
+
+        return src
+
+## Transformer Decoder layer
+class DeformableTransformerDecoderLayer(nn.Module):
+    def __init__(self,
+                 d_model     :int   = 256,
+                 num_heads   :int   = 8,
+                 num_levels  :int   = 3,
+                 num_points  :int   = 4,
+                 mlp_ratio   :float = 4.0,
+                 dropout     :float = 0.1,
+                 act_type    :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        # ---------------- Network parameters ----------------
+        ## Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        ## CrossAttention
+        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        ## FFN
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self,
+                tgt,
+                reference_points,
+                memory,
+                memory_spatial_shapes,
+                attn_mask=None,
+                memory_mask=None,
+                query_pos_embed=None):
+        # ---------------- MSHA for Object Query -----------------
+        q = k = self.with_pos_embed(tgt, query_pos_embed)
+        if attn_mask is not None:
+            # True -> attend (0 additive bias), False -> block (-inf), matching
+            # nn.MultiheadAttention's float-mask convention
+            attn_mask = torch.where(
+                attn_mask.bool(),
+                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=tgt.device),
+                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=tgt.device))
+        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # ---------------- CMHA for Object Query and Image-feature -----------------
+        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
+                               reference_points,
+                               memory,
+                               memory_spatial_shapes,
+                               memory_mask)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # ---------------- FeedForward Network -----------------
+        tgt = self.ffn(tgt)
+
+        return tgt
+
+## Transformer Decoder
+class DeformableTransformerDecoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 mlp_ratio      :float = 4.0,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pos_embed = None
+        # ----------- Network parameters -----------
+        self.decoder_layers = get_clones(
+            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)
+        self.return_intermediate = return_intermediate
+
+    def forward(self,
+                tgt,
+                ref_points_unact,
+                memory,
+                memory_spatial_shapes,
+                bbox_head,
+                score_head,
+                query_pos_head,
+                attn_mask=None,
+                memory_mask=None):
+        output = tgt
+        dec_out_bboxes = []
+        dec_out_logits = []
+        ref_points_detach = F.sigmoid(ref_points_unact)
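+        # Iterative box refinement: each layer predicts a delta in logit space,
+        # added to the inverse-sigmoid of the current reference points; the
+        # references are detached before conditioning the next layer.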
+        for i, layer in enumerate(self.decoder_layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            query_pos_embed = query_pos_head(ref_points_detach)
+
+            output = layer(output, ref_points_input, memory,
+                           memory_spatial_shapes, attn_mask,
+                           memory_mask, query_pos_embed)
+
+            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
+                ref_points_detach))
+
+            dec_out_logits.append(score_head[i](output))
+            if i == 0:
+                dec_out_bboxes.append(inter_ref_bbox)
+            else:
+                dec_out_bboxes.append(
+                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
+                        ref_points)))
+
+            ref_points = inter_ref_bbox
+            ref_points_detach = inter_ref_bbox.detach()
+
+        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+

+ 34 - 0
models/detectors/rtdetr/build.py

@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .rtdetr import RT_DETR
+
+
+# build object detector
+def build_rtdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    # -------------- Build RT-DETR --------------
+    model = RT_DETR(cfg             = cfg,
+                    num_classes     = num_classes,
+                    conf_thresh     = args.conf_thresh,
+                    topk            = args.topk,
+                    deploy          = deploy,
+                    no_multi_labels = args.no_multi_labels,
+                    )
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, num_classes)
+        
+    return model, criterion

+ 424 - 0
models/detectors/rtdetr/loss.py

@@ -0,0 +1,424 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from .loss_utils import varifocal_loss_with_logits, sigmoid_focal_loss
+    from .loss_utils import box_cxcywh_to_xyxy, bbox_iou
+    from .loss_utils import is_dist_avail_and_initialized, get_world_size
+    from .loss_utils import GIoULoss
+    from .matcher import HungarianMatcher
+except:
+    from loss_utils import varifocal_loss_with_logits, sigmoid_focal_loss
+    from loss_utils import box_cxcywh_to_xyxy, bbox_iou
+    from loss_utils import is_dist_avail_and_initialized, get_world_size
+    from loss_utils import GIoULoss
+    from matcher import HungarianMatcher
+
+
+# --------------- Criterion for RT-DETR ---------------
+def build_criterion(cfg, num_classes=80):
+    return Criterion(cfg, num_classes)
+
+class Criterion(object):
+    def __init__(self, cfg, num_classes=80):
+        self.matcher = HungarianMatcher(cfg['matcher_hpy']['cost_class'],
+                                        cfg['matcher_hpy']['cost_bbox'],
+                                        cfg['matcher_hpy']['cost_giou'],
+                                        alpha=0.25,
+                                        gamma=2.0)
+        self.loss = DINOLoss(num_classes = num_classes,
+                             matcher     = self.matcher,
+                             aux_loss    = True,
+                             use_vfl     = cfg['use_vfl'],
+                             loss_coeff  = cfg['loss_coeff'])
+
+    def __call__(self, dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta, targets=None):
+        assert targets is not None
+
+        gt_labels = [t['labels'] for t in targets]  # (List[torch.Tensor]) -> List[[N,]]
+        gt_boxes  = [t['boxes']  for t in targets]  # (List[torch.Tensor]) -> List[[N, 4]]
+
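+        # Denoising training: when dn_meta is provided, the query dimension of
+        # the decoder outputs holds the denoising queries first and the normal
+        # matching queries second; 'dn_num_split' stores the two lengths.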
+        if dn_meta is not None:
+            if isinstance(dn_meta, list):
+                dual_groups = len(dn_meta) - 1
+                dec_out_bboxes = torch.chunk(
+                    dec_out_bboxes, dual_groups + 1, dim=2)
+                dec_out_logits = torch.chunk(
+                    dec_out_logits, dual_groups + 1, dim=2)
+                enc_topk_bboxes = torch.chunk(
+                    enc_topk_bboxes, dual_groups + 1, dim=1)
+                enc_topk_logits = torch.chunk(
+                    enc_topk_logits, dual_groups + 1, dim=1)
+
+                loss = {}
+                for g_id in range(dual_groups + 1):
+                    if dn_meta[g_id] is not None:
+                        dn_out_bboxes_gid, dec_out_bboxes_gid = torch.split(
+                            dec_out_bboxes[g_id],
+                            dn_meta[g_id]['dn_num_split'],
+                            dim=2)
+                        dn_out_logits_gid, dec_out_logits_gid = torch.split(
+                            dec_out_logits[g_id],
+                            dn_meta[g_id]['dn_num_split'],
+                            dim=2)
+                    else:
+                        dn_out_bboxes_gid, dn_out_logits_gid = None, None
+                        dec_out_bboxes_gid = dec_out_bboxes[g_id]
+                        dec_out_logits_gid = dec_out_logits[g_id]
+                    out_bboxes_gid = torch.cat([
+                        enc_topk_bboxes[g_id].unsqueeze(0),
+                        dec_out_bboxes_gid
+                    ])
+                    out_logits_gid = torch.cat([
+                        enc_topk_logits[g_id].unsqueeze(0),
+                        dec_out_logits_gid
+                    ])
+                    loss_gid = self.loss(
+                        out_bboxes_gid,
+                        out_logits_gid,
+                        gt_boxes,
+                        gt_labels,
+                        dn_out_bboxes=dn_out_bboxes_gid,
+                        dn_out_logits=dn_out_logits_gid,
+                        dn_meta=dn_meta[g_id])
+                    # accumulate the group losses (0.0 avoids creating a CPU
+                    # tensor that would clash with CUDA loss values)
+                    for key, value in loss_gid.items():
+                        loss[key] = loss.get(key, 0.0) + value
+
+                # average across (dual_groups + 1)
+                for key, value in loss.items():
+                    loss.update({key: value / (dual_groups + 1)})
+                return loss
+            else:
+                dn_out_bboxes, dec_out_bboxes = torch.split(
+                    dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
+                dn_out_logits, dec_out_logits = torch.split(
+                    dec_out_logits, dn_meta['dn_num_split'], dim=2)
+        else:
+            dn_out_bboxes, dn_out_logits = None, None
+
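+        # The encoder's top-k proposals are prepended as an extra "layer" so
+        # they receive the same auxiliary supervision as the decoder layers.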
+        out_bboxes = torch.cat(
+            [enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
+        out_logits = torch.cat(
+            [enc_topk_logits.unsqueeze(0), dec_out_logits])
+
+        return self.loss(out_bboxes,
+                         out_logits,
+                         gt_boxes,
+                         gt_labels,
+                         dn_out_bboxes=dn_out_bboxes,
+                         dn_out_logits=dn_out_logits,
+                         dn_meta=dn_meta)
+
+
+# --------------- DETR series loss ---------------
+class DETRLoss(nn.Module):
+    """Modified Paddle DETRLoss class without mask loss."""
+    def __init__(self,
+                 num_classes=80,
+                 matcher='HungarianMatcher',
+                 aux_loss=True,
+                 use_vfl=False,
+                 loss_coeff={'class': 1,
+                             'bbox': 5,
+                             'giou': 2,
+                             'no_object': 0.1,},
+                 ):
+        super(DETRLoss, self).__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.loss_coeff = loss_coeff
+        self.aux_loss = aux_loss
+        self.use_vfl = use_vfl
+        self.giou_loss = GIoULoss(reduction='none')
+
+    def _get_loss_class(self,
+                        logits,
+                        gt_class,
+                        match_indices,
+                        bg_index,
+                        num_gts,
+                        postfix="",
+                        iou_score=None):
+        # logits: [b, query, num_classes], gt_class: list[[n, 1]]
+        name_class = "loss_class" + postfix
+
+        target_label = torch.full(logits.shape[:2], bg_index, dtype=torch.int64, device=logits.device)
+        bs, num_query_objects = target_label.shape
+        num_gt = sum(len(a) for a in gt_class)
+        if num_gt > 0:
+            index, updates = self._get_index_updates(
+                num_query_objects, gt_class, match_indices)
+            target_label = target_label.reshape(-1, 1)
+            target_label[index] = updates.long()[:, None]
+            target_label = target_label.reshape(bs, num_query_objects)
+
+        # one-hot label
+        target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1].float()
+        if iou_score is not None and self.use_vfl:
+            target_score = torch.zeros([bs, num_query_objects], device=logits.device)
+            if num_gt > 0:
+                target_score = target_score.reshape(-1, 1)
+                target_score[index] = iou_score.float()
+            target_score = target_score.reshape(bs, num_query_objects, 1) * target_label
+            loss_cls = varifocal_loss_with_logits(logits,
+                                                  target_score,
+                                                  target_label,
+                                                  num_gts)
+        else:
+            loss_cls = sigmoid_focal_loss(logits,
+                                          target_label,
+                                          num_gts)
+
+        return {name_class: loss_cls * self.loss_coeff['class']}
+
+    def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
+                       postfix=""):
+        # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
+        name_bbox = "loss_bbox" + postfix
+        name_giou = "loss_giou" + postfix
+
+        loss = dict()
+        if sum(len(a) for a in gt_bbox) == 0:
+            loss[name_bbox] = torch.as_tensor([0.], device=boxes.device)
+            loss[name_giou] = torch.as_tensor([0.], device=boxes.device)
+            return loss
+
+        # prepare positive samples
+        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox, match_indices)
+
+        # Compute L1 loss
+        loss[name_bbox] = self.loss_coeff['bbox'] * F.l1_loss(
+            src_bbox, target_bbox, reduction='sum') / num_gts
+        
+        # Compute GIoU loss
+        loss[name_giou] = self.giou_loss(
+            box_cxcywh_to_xyxy(src_bbox), box_cxcywh_to_xyxy(target_bbox))
+        loss[name_giou] = loss[name_giou].sum() / num_gts
+        loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
+
+        return loss
+
+    def _get_loss_aux(self,
+                      boxes,
+                      logits,
+                      gt_bbox,
+                      gt_class,
+                      bg_index,
+                      num_gts,
+                      dn_match_indices=None,
+                      postfix=""):
+        loss_class = []
+        loss_bbox, loss_giou = [], []
+        if dn_match_indices is not None:
+            match_indices = dn_match_indices
+        for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)):
+            if dn_match_indices is None:
+                match_indices = self.matcher(
+                    aux_boxes,
+                    aux_logits,
+                    gt_bbox,
+                    gt_class,
+                    )
+            if self.use_vfl:
+                if sum(len(a) for a in gt_bbox) > 0:
+                    src_bbox, target_bbox = self._get_src_target_assign(
+                        aux_boxes.detach(), gt_bbox, match_indices)
+                    iou_score = bbox_iou(box_cxcywh_to_xyxy(src_bbox),
+                                         box_cxcywh_to_xyxy(target_bbox))
+                else:
+                    iou_score = None
+            else:
+                iou_score = None
+            loss_class.append(
+                self._get_loss_class(aux_logits, gt_class, match_indices,
+                                     bg_index, num_gts, postfix, iou_score)[
+                                         'loss_class' + postfix])
+            loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
+                                        num_gts, postfix)
+            loss_bbox.append(loss_['loss_bbox' + postfix])
+            loss_giou.append(loss_['loss_giou' + postfix])
+
+        loss = {
+            "loss_class_aux" + postfix: sum(loss_class),
+            "loss_bbox_aux"  + postfix: sum(loss_bbox),
+            "loss_giou_aux"  + postfix: sum(loss_giou)
+        }
+
+        return loss
+
+    def _get_index_updates(self, num_query_objects, target, match_indices):
+        batch_idx = torch.cat([
+            torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)
+        ])
+        src_idx = torch.cat([src for (src, _) in match_indices])
+        src_idx += (batch_idx * num_query_objects)
+        target_assign = torch.cat([
+            torch.gather(t, 0, dst) for t, (_, dst) in zip(target, match_indices)
+        ])
+        return src_idx, target_assign
+
+    def _get_src_target_assign(self, src, target, match_indices):
+        src_assign = torch.cat([
+            t[I] if len(I) > 0 else torch.zeros([0, t.shape[-1]], device=t.device)
+            for t, (I, _) in zip(src, match_indices)
+        ])
+        target_assign = torch.cat([
+            t[J] if len(J) > 0 else torch.zeros([0, t.shape[-1]], device=t.device)
+            for t, (_, J) in zip(target, match_indices)
+        ])
+        return src_assign, target_assign
+
+    def _get_num_gts(self, targets):
+        num_gts = sum(len(a) for a in targets)
+        num_gts = torch.as_tensor([num_gts], dtype=torch.float32, device=targets[0].device)
+
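+        # Average across all processes so that every rank normalizes its loss
+        # by the same global statistic (keeps gradients consistent under DDP).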
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_gts)
+        num_gts = torch.clamp(num_gts / get_world_size(), min=1).item()
+
+        return num_gts
+
+    def _get_prediction_loss(self,
+                             boxes,
+                             logits,
+                             gt_bbox,
+                             gt_class,
+                             postfix="",
+                             dn_match_indices=None,
+                             num_gts=1):
+        if dn_match_indices is None:
+            match_indices = self.matcher(boxes, logits, gt_bbox, gt_class)
+        else:
+            match_indices = dn_match_indices
+
+        if self.use_vfl:
+            if sum(len(a) for a in gt_bbox) > 0:
+                src_bbox, target_bbox = self._get_src_target_assign(
+                    boxes.detach(), gt_bbox, match_indices)
+                iou_score = bbox_iou(box_cxcywh_to_xyxy(src_bbox),
+                                     box_cxcywh_to_xyxy(target_bbox))
+            else:
+                iou_score = None
+        else:
+            iou_score = None
+
+        loss = dict()
+        loss.update(
+            self._get_loss_class(logits, gt_class, match_indices,
+                                 self.num_classes, num_gts, postfix, iou_score))
+        loss.update(
+            self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts,
+                                postfix))
+
+        return loss
+
+    def forward(self,
+                boxes,
+                logits,
+                gt_bbox,
+                gt_class,
+                postfix="",
+                **kwargs):
+        r"""
+        Args:
+            boxes (Tensor): [l, b, query, 4]
+            logits (Tensor): [l, b, query, num_classes]
+            gt_bbox (List(Tensor)): list[[n, 4]]
+            gt_class (List(Tensor)): list[[n, 1]]
+            masks (Tensor, optional): [l, b, query, h, w]
+            gt_mask (List(Tensor), optional): list[[n, H, W]]
+            postfix (str): postfix of loss name
+        """
+
+        dn_match_indices = kwargs.get("dn_match_indices", None)
+        num_gts = kwargs.get("num_gts", None)
+        if num_gts is None:
+            num_gts = self._get_num_gts(gt_class)
+
+        total_loss = self._get_prediction_loss(
+            boxes[-1],
+            logits[-1],
+            gt_bbox,
+            gt_class,
+            postfix=postfix,
+            dn_match_indices=dn_match_indices,
+            num_gts=num_gts)
+
+        if self.aux_loss:
+            total_loss.update(
+                self._get_loss_aux(
+                    boxes[:-1],
+                    logits[:-1],
+                    gt_bbox,
+                    gt_class,
+                    self.num_classes,
+                    num_gts,
+                    dn_match_indices,
+                    postfix,
+                    ))
+
+        return total_loss
+
+class DINOLoss(DETRLoss):
+    def forward(self,
+                boxes,
+                logits,
+                gt_bbox,
+                gt_class,
+                postfix="",
+                dn_out_bboxes=None,
+                dn_out_logits=None,
+                dn_meta=None,
+                **kwargs):
+        num_gts = self._get_num_gts(gt_class)
+        total_loss = super(DINOLoss, self).forward(
+            boxes, logits, gt_bbox, gt_class, num_gts=num_gts)
+
+        if dn_meta is not None:
+            dn_positive_idx, dn_num_group = \
+                dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
+            assert len(gt_class) == len(dn_positive_idx)
+
+            # denoising match indices
+            dn_match_indices = self.get_dn_match_indices(
+                gt_class, dn_positive_idx, dn_num_group)
+
+            # compute denoising training loss
+            num_gts *= dn_num_group
+            dn_loss = super(DINOLoss, self).forward(
+                dn_out_bboxes,
+                dn_out_logits,
+                gt_bbox,
+                gt_class,
+                postfix="_dn",
+                dn_match_indices=dn_match_indices,
+                num_gts=num_gts)
+            total_loss.update(dn_loss)
+        else:
+            total_loss.update(
+                {k + '_dn': torch.as_tensor([0.], device=boxes.device)
+                 for k in total_loss.keys()})
+
+        return total_loss
+
+    @staticmethod
+    def get_dn_match_indices(labels, dn_positive_idx, dn_num_group):
+        dn_match_indices = []
+        for i in range(len(labels)):
+            num_gt = len(labels[i])
+            if num_gt > 0:
+                gt_idx = torch.arange(num_gt).long()
+                gt_idx = gt_idx.tile([dn_num_group])
+                assert len(dn_positive_idx[i]) == len(gt_idx)
+                dn_match_indices.append((dn_positive_idx[i], gt_idx))
+            else:
+                dn_match_indices.append((torch.zeros([0], dtype=torch.int64),
+                                         torch.zeros([0], dtype=torch.int64)))
+        return dn_match_indices
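+
+    # Illustrative example: with 2 GT objects and dn_num_group = 3, gt_idx is
+    # tiled to [0, 1, 0, 1, 0, 1], so every denoising group reuses the same
+    # ground-truth targets while dn_positive_idx points at that group's
+    # positive denoising queries.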

+ 240 - 0
models/detectors/rtdetr/loss_utils.py

@@ -0,0 +1,240 @@
+import math
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+from torchvision.ops.boxes import box_area
+
+
+# ------------------------- For loss -------------------------
+## FocalLoss
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+                positive vs. negative examples. Default: 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.sum() / num_boxes
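+
+# For intuition: with alpha = 0.25 and gamma = 2, a well-classified positive
+# (p = 0.9) is down-weighted by 0.25 * (1 - 0.9)**2 = 0.0025, while a hard
+# positive (p = 0.1) keeps a weight of 0.25 * 0.9**2 = 0.2025, so easy
+# examples contribute almost nothing to the loss.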
+
+## Varifocal Loss
+def varifocal_loss_with_logits(pred_logits,
+                               gt_score,
+                               label,
+                               normalizer=1.0,
+                               alpha=0.75,
+                               gamma=2.0):
+    pred_score = torch.sigmoid(pred_logits)
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
+    loss = F.binary_cross_entropy_with_logits(pred_logits, gt_score, reduction='none')
+    loss *= weight
+
+    return loss.sum() / normalizer
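+
+# Unlike focal loss, VFL down-weights only the negatives (by alpha * p**gamma);
+# positives are weighted by the IoU-aware target score gt_score, so that
+# high-quality matches dominate the classification gradient.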
+
+## InverseSigmoid
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1/x2)
+
+## GIoU loss
+class GIoULoss(object):
+    """ Modified GIoULoss from Paddle-Paddle"""
+    def __init__(self, eps=1e-10, reduction='none'):
+        self.eps = eps
+        self.reduction = reduction
+        assert reduction in ('none', 'mean', 'sum')
+
+    def bbox_overlap(self, box1, box2, eps=1e-10):
+        """Calculate the IoU of box1 and box2.
+        Args:
+            box1 (list[Tensor]): [x1, y1, x2, y2], each with the shape (..., 1)
+            box2 (list[Tensor]): [x1, y1, x2, y2], each with the shape (..., 1)
+            eps (float): epsilon to avoid division by zero
+        Return:
+            iou (Tensor): IoU of box1 and box2
+            overlap (Tensor): intersection area of box1 and box2
+            union (Tensor): union area of box1 and box2
+        """
+        x1, y1, x2, y2 = box1
+        x1g, y1g, x2g, y2g = box2
+
+        xkis1 = torch.max(x1, x1g)
+        ykis1 = torch.max(y1, y1g)
+        xkis2 = torch.min(x2, x2g)
+        ykis2 = torch.min(y2, y2g)
+        w_inter = (xkis2 - xkis1).clip(0)
+        h_inter = (ykis2 - ykis1).clip(0)
+        overlap = w_inter * h_inter
+
+        area1 = (x2 - x1) * (y2 - y1)
+        area2 = (x2g - x1g) * (y2g - y1g)
+        union = area1 + area2 - overlap + eps
+        iou = overlap / union
+
+        return iou, overlap, union
+
+    def __call__(self, pbox, gbox):
+        x1, y1, x2, y2 = torch.chunk(pbox, 4, dim=-1)
+        x1g, y1g, x2g, y2g = torch.chunk(gbox, 4, dim=-1)
+        box1 = [x1, y1, x2, y2]
+        box2 = [x1g, y1g, x2g, y2g]
+        iou, _, union = self.bbox_overlap(box1, box2, self.eps)
+        xc1 = torch.min(x1, x1g)
+        yc1 = torch.min(y1, y1g)
+        xc2 = torch.max(x2, x2g)
+        yc2 = torch.max(y2, y2g)
+
+        area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
+        miou = iou - ((area_c - union) / area_c)
+        giou = 1 - miou
+
+        if self.reduction == 'none':
+            loss = giou
+        elif self.reduction == 'sum':
+            loss = giou.sum()
+        elif self.reduction == 'mean':
+            loss = giou.mean()
+
+        return loss
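+
+# Usage sketch (inputs are assumed to be xyxy boxes):
+#   giou_loss = GIoULoss(reduction='none')
+#   loss = giou_loss(pred_xyxy, target_xyxy)   # per-box values in [0, 2]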
+
+
+# ------------------------- For box -------------------------
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+def box_iou(boxes1, boxes2):
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    iou = inter / union
+    return iou, union
+
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+
+    The boxes should be in [x0, y0, x1, y1] format
+
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes gives inf / nan results
+    # so do an early check
+    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    iou, union = box_iou(boxes1, boxes2)
+
+    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    area = wh[:, :, 0] * wh[:, :, 1]
+
+    return iou - (area - union) / area
+
+def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
+    """Modified from PaddlePaddle.
+    Args:
+        box1 (Tensor): predicted boxes in xyxy format, with the shape [..., 4]
+        box2 (Tensor): target boxes in xyxy format, with the shape [..., 4]
+        giou (bool): whether to compute GIoU instead of IoU, default False
+        diou (bool): whether to compute DIoU instead of IoU, default False
+        ciou (bool): whether to compute CIoU instead of IoU, default False
+        eps (float): epsilon to avoid division by zero
+    Return:
+        iou (Tensor): IoU (or the selected variant) of box1 and box2, with the shape [..., 1]
+    """
+    px1, py1, px2, py2 = torch.chunk(box1, 4, -1)
+    gx1, gy1, gx2, gy2 = torch.chunk(box2, 4, -1)
+    x1 = torch.max(px1, gx1)
+    y1 = torch.max(py1, gy1)
+    x2 = torch.min(px2, gx2)
+    y2 = torch.min(py2, gy2)
+
+    overlap = ((x2 - x1).clamp(0)) * ((y2 - y1).clamp(0))
+
+    area1 = (px2 - px1) * (py2 - py1)
+    area1 = area1.clamp(0)
+
+    area2 = (gx2 - gx1) * (gy2 - gy1)
+    area2 = area2.clamp(0)
+
+    union = area1 + area2 - overlap + eps
+    iou = overlap / union
+
+    if giou or ciou or diou:
+        # convex w, h
+        cw = torch.max(px2, gx2) - torch.min(px1, gx1)
+        ch = torch.max(py2, gy2) - torch.min(py1, gy1)
+        if giou:
+            c_area = cw * ch + eps
+            return iou - (c_area - union) / c_area
+        else:
+            # convex diagonal squared
+            c2 = cw**2 + ch**2 + eps
+            # center distance
+            rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
+            if diou:
+                return iou - rho2 / c2
+            else:
+                w1, h1 = px2 - px1, py2 - py1 + eps
+                w2, h2 = gx2 - gx1, gy2 - gy1 + eps
+                delta = torch.atan(w1 / h1) - torch.atan(w2 / h2)
+                v = (4 / math.pi**2) * torch.pow(delta, 2)
+                alpha = v / (1 + eps - iou + v)
+                alpha = alpha.detach()  # no gradient through the trade-off term
+                return iou - (rho2 / c2 + v * alpha)
+    else:
+        return iou
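+
+# Usage sketch: the flags pick the IoU variant; with all flags False the
+# plain IoU is returned.
+#   giou = bbox_iou(pred_xyxy, tgt_xyxy, giou=True)   # values in [-1, 1]
+#   ciou = bbox_iou(pred_xyxy, tgt_xyxy, ciou=True)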
+
+
+# ------------------------- For distributed -------------------------
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
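+
+
+if __name__ == '__main__':
+    # Minimal sanity checks (illustrative only); the tensors below are
+    # assumptions chosen for demonstration, not values from the pipeline.
+    logits  = torch.randn(2, 300, 80)
+    targets = torch.zeros(2, 300, 80)
+    print('focal loss :', sigmoid_focal_loss(logits, targets, num_boxes=8.).item())
+
+    boxes1 = torch.tensor([[0., 0., 2., 2.]])
+    boxes2 = torch.tensor([[1., 1., 3., 3.]])
+    iou, _ = box_iou(boxes1, boxes2)
+    print('iou        :', iou.item())                                  # 1 / 7 ~= 0.143
+    print('giou       :', generalized_box_iou(boxes1, boxes2).item())  # ~= -0.079
+
+    giou_loss = GIoULoss(reduction='mean')
+    print('giou loss  :', giou_loss(boxes1, boxes2).item())            # 1 - giou ~= 1.079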

+ 52 - 0
models/detectors/rtdetr/matcher.py

@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+try:
+    from .loss_utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou
+except ImportError:
+    from  loss_utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+    def __init__(self, cost_class, cost_bbox, cost_giou, alpha=0.25, gamma=2.0):
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox = cost_bbox
+        self.cost_giou = cost_giou
+        self.alpha = alpha
+        self.gamma = gamma
+
+    @torch.no_grad()
+    def forward(self, pred_boxes, pred_logits, gt_boxes, gt_labels):
+        bs, num_queries = pred_logits.shape[:2]
+        # [B, Nq, C] -> [BNq, C]
+        out_prob = pred_logits.flatten(0, 1).sigmoid()
+        out_bbox = pred_boxes.flatten(0, 1)
+
+        # List of per-image [M,] labels / [M, 4] boxes -> [sum(M),] / [sum(M), 4]
+        tgt_ids = torch.cat(gt_labels).long()
+        tgt_bbox = torch.cat(gt_boxes).float()
+
+        # -------------------- Classification cost --------------------
+        neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
+        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+        # -------------------- Regression cost --------------------
+        ## L1 cost: [BNq, M]
+        cost_bbox = torch.cdist(out_bbox, box_xyxy_to_cxcywh(tgt_bbox).to(out_bbox.device), p=1)
+        ## GIoU cost: [BNq, M]
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), tgt_bbox.to(out_bbox.device))
+
+        # Final cost: [B, Nq, M]
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        C = C.view(bs, num_queries, -1).cpu()
+
+        # Label assignment
+        sizes = [len(t) for t in gt_boxes]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+    
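+
+
+if __name__ == '__main__':
+    # A minimal smoke test (illustrative only): the shapes and boxes below are
+    # assumptions, not values taken from the training pipeline.
+    pred_boxes  = torch.rand(2, 300, 4)    # [bs, num_queries, 4], normalized cxcywh
+    pred_logits = torch.randn(2, 300, 80)  # [bs, num_queries, num_classes]
+    gt_boxes  = [torch.tensor([[0.1, 0.1, 0.3, 0.4]]),
+                 torch.tensor([[0.2, 0.2, 0.6, 0.8]])]
+    gt_labels = [torch.tensor([3]), torch.tensor([17])]
+
+    matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
+    indices = matcher(pred_boxes, pred_logits, gt_boxes, gt_labels)
+    print(indices)  # one (query_idx, gt_idx) pair of index tensors per image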

+ 202 - 2
models/detectors/rtdetr/rtdetr.py

@@ -1,5 +1,205 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .rtdetr_encoder import build_image_encoder
+    from .rtdetr_decoder import build_transformer
+except ImportError:
+    from  rtdetr_encoder import build_image_encoder
+    from  rtdetr_decoder import build_transformer
+
+
 # Real-time Transformer-based Object Detector
+class RT_DETR(nn.Module):
+    def __init__(self,
+                 cfg,
+                 num_classes = 80,
+                 conf_thresh = 0.1,
+                 topk        = 100,
+                 deploy      = False,
+                 no_multi_labels = False,
+                 ):
+        super().__init__()
+        # ----------- Basic setting -----------
+        self.num_classes = num_classes
+        self.num_topk = topk
+        self.conf_thresh = conf_thresh
+        self.no_multi_labels = no_multi_labels
+        self.deploy = deploy
+
+        # ----------- Network setting -----------
+        ## Image encoder
+        self.image_encoder = build_image_encoder(cfg)
+        self.fpn_dims = self.image_encoder.fpn_dims
+
+        ## Detect decoder
+        self.detect_decoder = build_transformer(cfg, self.fpn_dims, num_classes, return_intermediate=self.training)
+
+    def post_process(self, box_pred, cls_pred):
+        if self.no_multi_labels:
+            # [bs, Nq, C] -> [Nq, C]: the batch size is assumed to be 1 at inference
+            cls_pred = cls_pred[0]
+            box_pred = box_pred[0]
+            # [Nq,]: best score & class label per query
+            scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
+
+            # Keep top k top-scoring indices only.
+            num_topk = min(self.num_topk, box_pred.size(0))
+
+            # Topk candidates
+            predicted_prob, topk_idxs = scores.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # Filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # Top-k results
+            topk_scores = topk_scores[keep_idxs]
+            topk_labels = labels[topk_idxs]
+            topk_bboxes = box_pred[topk_idxs]
+
+            return topk_bboxes, topk_scores, topk_labels
+        else:
+            # Top-k select
+            cls_pred = cls_pred[0].flatten().sigmoid_()
+            box_pred = box_pred[0]
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.num_topk, box_pred.size(0))
+
+            # Top-k candidates over the flattened [Nq * C] scores
+            predicted_prob, topk_idxs = cls_pred.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # Filter out the proposals with low confidence scores
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+            # Recover the query index and the class label from the flattened index
+            topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+
+            ## Top-k results
+            topk_labels = topk_idxs % self.num_classes
+            topk_bboxes = box_pred[topk_box_idxs]
+
+        return topk_bboxes, topk_scores, topk_labels
+    
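+    # Illustrative example: with num_classes = 80, a flattened index of 163
+    # decodes to query index 163 // 80 = 2 and class label 163 % 80 = 3.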
+    def forward(self, x, targets=None):
+        # ----------- Image Encoder -----------
+        pyramid_feats = self.image_encoder(x)
+
+        # ----------- Transformer -----------
+        transformer_outputs = self.detect_decoder(pyramid_feats, targets)
+        pred_boxes, pred_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = transformer_outputs
+
+        if self.training:
+            return transformer_outputs
+        else:
+            box_preds = pred_boxes[-1]
+            cls_preds = pred_logits[-1]
+            
+            # Post-process: top-k selection & confidence filtering
+            bboxes, scores, labels = self.post_process(box_preds, cls_preds)
+
+            return bboxes, scores, labels
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    from loss import build_criterion
+
+    # Model config
+    cfg = {
+        'width': 1.0,
+        'depth': 1.0,
+        'out_stride': [8, 16, 32],
+        # Image Encoder - Backbone
+        'backbone': 'resnet18',
+        'backbone_norm': 'BN',
+        'res5_dilation': False,
+        'pretrained': True,
+        'pretrained_weight': 'imagenet1k_v1',
+        # Image Encoder - FPN
+        'fpn': 'hybrid_encoder',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        'hidden_dim': 256,
+        'en_num_heads': 8,
+        'en_num_layers': 1,
+        'en_mlp_ratio': 4.0,
+        'en_dropout': 0.1,
+        'pe_temperature': 10000.,
+        'en_act': 'gelu',
+        # Transformer Decoder
+        'transformer': 'rtdetr_transformer',
+        'de_num_heads': 8,
+        'de_num_layers': 6,
+        'de_mlp_ratio': 4.0,
+        'de_dropout': 0.0,
+        'de_act': 'gelu',
+        'de_num_points': 4,
+        'num_queries': 300,
+        'learnt_init_query': False,
+        'dn_num_denoising': 100,
+        'dn_label_noise_ratio': 0.5,
+        'dn_box_noise_scale': 1,
+        # Head
+        'det_head': 'dino_head',
+        # Matcher
+        'matcher_hpy': {'cost_class': 2.0,
+                        'cost_bbox': 5.0,
+                        'cost_giou': 2.0,},
+        # Loss
+        'use_vfl': True,
+        'loss_coeff': {'class': 1,
+                       'bbox': 5,
+                       'giou': 2,
+                       'no_object': 0.1,},
+        }
+    bs = 1
+    # Create a batch of images & targets
+    image = torch.randn(bs, 3, 640, 640)
+    targets = [{
+        'labels': torch.tensor([2, 4, 5, 8]).long(),
+        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float() / 640.
+    }] * bs
+
+    # Create model
+    model = RT_DETR(cfg, num_classes=80)
+    model.train()
+
+    # Create criterion
+    criterion = build_criterion(cfg, num_classes=80)
+
+    # Model inference
+    t0 = time.time()
+    outputs = model(image, targets)
+    t1 = time.time()
+    print('Infer time: ', t1 - t0)
 
+    # Compute loss
+    loss = criterion(*outputs, targets)
+    for k in loss.keys():
+        print("{} : {}".format(k, loss[k].item()))
 
-class RT_DETR():
-    pass
+    print('==============================')
+    model.eval()
+    flops, params = profile(model, inputs=(image, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 90 - 18
models/detectors/rtdetr/rtdetr_decoder.py

@@ -6,13 +6,38 @@ from torch.nn.init import constant_, xavier_uniform_, uniform_
 from typing import List
 
 try:
-    from .basic_modules.basic import BasicConv, MLP, DeformableTransformerDecoder
+    from .basic_modules.basic import BasicConv, MLP
+    from .basic_modules.transformer import DeformableTransformerDecoder
     from .basic_modules.dn_compoments import get_contrastive_denoising_training_group
 except:
-    from  basic_modules.basic import BasicConv, MLP, DeformableTransformerDecoder
+    from  basic_modules.basic import BasicConv, MLP
+    from  basic_modules.transformer import DeformableTransformerDecoder
     from  basic_modules.dn_compoments import get_contrastive_denoising_training_group
 
 
+def build_transformer(cfg, in_dims, num_classes, return_intermediate=False):
+    if cfg['transformer'] == 'rtdetr_transformer':
+        return RTDETRTransformer(in_dims             = in_dims,
+                                 hidden_dim          = cfg['hidden_dim'],
+                                 strides             = cfg['out_stride'],
+                                 num_classes         = num_classes,
+                                 num_queries         = cfg['num_queries'],
+                                 pos_embed_type      = 'sine',
+                                 num_heads           = cfg['de_num_heads'],
+                                 num_layers          = cfg['de_num_layers'],
+                                 num_levels          = len(cfg['out_stride']),
+                                 num_points          = cfg['de_num_points'],
+                                 mlp_ratio           = cfg['de_mlp_ratio'],
+                                 dropout             = cfg['de_dropout'],
+                                 act_type            = cfg['de_act'],
+                                 return_intermediate = return_intermediate,
+                                 num_denoising       = cfg['dn_num_denoising'],
+                                 label_noise_ratio   = cfg['dn_label_noise_ratio'],
+                                 box_noise_scale     = cfg['dn_box_noise_scale'],
+                                 learnt_init_query   = cfg['learnt_init_query'],
+                                 )
+    else:
+        raise NotImplementedError("Unknown transformer: {}".format(cfg['transformer']))
+
+
 # ----------------- Decoder for Detection task -----------------
 ## RTDETR's Transformer for Detection task
 class RTDETRTransformer(nn.Module):
@@ -24,14 +49,12 @@ class RTDETRTransformer(nn.Module):
                  num_classes    :int  = 80,
                  num_queries    :int  = 300,
                  pos_embed_type :str  = 'sine',
-                 trainable      :bool = False,
                  # transformer parameters
                  num_heads      :int   = 8,
                  num_layers     :int   = 1,
                  num_levels     :int   = 3,
                  num_points     :int   = 4,
                  mlp_ratio      :float = 4.0,
-                 pe_temperature :float = 10000.,
                  dropout        :float = 0.1,
                  act_type       :str   = "relu",
                  return_intermediate :bool = False,
@@ -46,7 +69,6 @@ class RTDETRTransformer(nn.Module):
         ## Basic parameters
         self.in_dims = in_dims
         self.strides = strides
-        self.trainable = trainable
         self.num_queries = num_queries
         self.pos_embed_type = pos_embed_type
         self.num_classes = num_classes
@@ -59,7 +81,6 @@ class RTDETRTransformer(nn.Module):
         self.mlp_ratio  = mlp_ratio
         self.dropout    = dropout
         self.act_type   = act_type
-        self.pe_temperature = pe_temperature
         self.return_intermediate = return_intermediate
         ## Denoising parameters
         self.num_denoising = num_denoising
@@ -82,9 +103,8 @@ class RTDETRTransformer(nn.Module):
             num_levels = num_levels,
             num_points = num_points,
             mlp_ratio  = mlp_ratio,
-            pe_temperature = pe_temperature,
-            dropout        = dropout,
-            act_type       = act_type,
+            dropout    = dropout,
+            act_type   = act_type,
             return_intermediate = return_intermediate
             )
         
@@ -142,8 +162,8 @@ class RTDETRTransformer(nn.Module):
             xavier_uniform_(self.tgt_embed.weight)
         xavier_uniform_(self.query_pos_head.layers[0].weight)
         xavier_uniform_(self.query_pos_head.layers[1].weight)
-        for l in self.input_proj:
-            xavier_uniform_(l[0].weight)
+        for l in self.input_proj_layers:
+            xavier_uniform_(l.conv.weight)
 
     def generate_anchors(self, spatial_shapes, grid_size=0.05):
         anchors = []
@@ -197,26 +217,27 @@ class RTDETRTransformer(nn.Module):
         memory = torch.where(valid_mask, memory, torch.as_tensor(0.))
         output_memory = self.enc_output(memory)
 
+        # [bs, num_queries, c]
         enc_outputs_class = self.enc_class_head(output_memory)
         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
 
         topk = self.num_queries
-        topk_ind = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
-        reference_points_unact = torch.gather(enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))
+        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, topk]
+        reference_points_unact = torch.gather(enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4)) # [bs, topk, 4]
         enc_topk_bboxes = F.sigmoid(reference_points_unact)
 
         if denoising_bbox_unact is not None:
             reference_points_unact = torch.cat(
                 [denoising_bbox_unact, reference_points_unact], 1)
-        if self.trainable:
+        if self.training:
             reference_points_unact = reference_points_unact.detach()
-        enc_topk_logits = torch.gather(enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))
+        enc_topk_logits = torch.gather(enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, topk, nc]
 
         # extract region features
         if self.learnt_init_query:
             target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
         else:
-            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[1]))
+            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
             if self.training:
                 target = target.detach()
         if denoising_class is not None:
@@ -255,8 +276,7 @@ class RTDETRTransformer(nn.Module):
                                                           self.query_pos_head,
                                                           attn_mask)
         
-        return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
-                dn_meta)
+        return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
 
 
 # ----------------- Decoder for Segmentation task -----------------
@@ -279,3 +299,55 @@ class PosTransformerDecoder(nn.Module):
 
     def forward(self, x):
         return
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'out_stride': [8, 16, 32],
+        # Transformer Decoder
+        'transformer': 'rtdetr_transformer',
+        'hidden_dim': 256,
+        'de_num_heads': 8,
+        'de_num_layers': 6,
+        'de_mlp_ratio': 4.0,
+        'de_dropout': 0.1,
+        'de_act': 'gelu',
+        'de_num_points': 4,
+        'num_queries': 300,
+        'learnt_init_query': False,
+        'pe_temperature': 10000.,
+        'dn_num_denoising': 100,
+        'dn_label_noise_ratio': 0.5,
+        'dn_box_noise_scale': 1,
+    }
+    bs = 1
+    hidden_dim = cfg['hidden_dim']
+    in_dims = [hidden_dim] * 3
+    targets = [{
+        'labels': torch.tensor([2, 4, 5, 8]).long(),
+        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float() / 640.
+    }] * bs
+    pyramid_feats = [torch.randn(bs, hidden_dim, 80, 80),
+                     torch.randn(bs, hidden_dim, 40, 40),
+                     torch.randn(bs, hidden_dim, 20, 20)]
+    model = build_transformer(cfg, in_dims, 80, True)
+    model.train()
+
+    t0 = time.time()
+    outputs = model(pyramid_feats, targets)
+    out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(out_bboxes.shape)
+    print(out_logits.shape)
+    print(enc_topk_bboxes.shape)
+    print(enc_topk_logits.shape)
+
+    print('==============================')
+    model.eval()
+    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 9 - 5
models/detectors/rtdetr/rtdetr_encoder.py

@@ -11,11 +11,11 @@ except:
 
 
 # ----------------- Image Encoder -----------------
-def build_image_encoder(cfg, trainable=False):
-    return ImageEncoder(cfg, trainable)
+def build_image_encoder(cfg):
+    return ImageEncoder(cfg)
 
 class ImageEncoder(nn.Module):
-    def __init__(self, cfg, trainable=False):
+    def __init__(self, cfg):
         super().__init__()
         # ---------------- Basic settings ----------------
         ## Basic parameters
@@ -27,10 +27,11 @@ class ImageEncoder(nn.Module):
         
         # ---------------- Network settings ----------------
         ## Backbone Network
-        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained']&trainable)
+        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained'] and self.training)  # nn.Module is in training mode at construction time
 
         ## Feature Pyramid Network
         self.fpn = build_fpn(cfg, fpn_feat_dims, self.hidden_dim)
+        self.fpn_dims = self.fpn.out_dims
         
     def forward(self, x):
         pyramid_feats = self.backbone(x)
@@ -66,7 +67,8 @@ if __name__ == '__main__':
         'en_act': 'gelu',
     }
     x = torch.rand(2, 3, 640, 640)
-    model = build_image_encoder(cfg, True)
+    model = build_image_encoder(cfg)
+    model.train()
 
     t0 = time.time()
     outputs = model(x)
@@ -76,6 +78,8 @@ if __name__ == '__main__':
         print(out.shape)
 
     print('==============================')
+    model.eval()
+    x = torch.rand(1, 3, 640, 640)
     flops, params = profile(model, inputs=(x, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))

+ 13 - 0
utils/misc.py

@@ -231,6 +231,19 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
 
     return loss.mean(1).sum() / num_boxes
 
+## Varifocal Loss
+def varifocal_loss_with_logits(pred_logits,
+                               gt_score,
+                               label,
+                               normalizer=1.0,
+                               alpha=0.75,
+                               gamma=2.0):
+    pred_score = torch.sigmoid(pred_logits)
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
+    loss = F.binary_cross_entropy_with_logits(
+        pred_logits, gt_score, weight=weight, reduction='none')
+    return loss.mean(1).sum() / normalizer
+
 ## InverseSigmoid
 def inverse_sigmoid(x, eps=1e-5):
     x = x.clamp(min=0, max=1)