yjh0410 1 year ago
Parent
Commit
dfb6c1ecec

+ 0 - 3
config/model_config/rtdetr_config.py

@@ -11,7 +11,6 @@ rtdetr_cfg = {
         ## Image Encoder - Backbone
         'backbone': 'resnet18',
         'backbone_norm': 'FrozeBN',
-        'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
         'freeze_at': 0,
@@ -72,7 +71,6 @@ rtdetr_cfg = {
         ## Image Encoder - Backbone
         'backbone': 'resnet50',
         'backbone_norm': 'FrozeBN',
-        'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
         'freeze_at': 0,
@@ -133,7 +131,6 @@ rtdetr_cfg = {
         ## Image Encoder - Backbone
         'backbone': 'resnet101',
         'backbone_norm': 'FrozeBN',
-        'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
         'freeze_at': 0,
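
For context, a minimal sketch of what one backbone entry in `rtdetr_cfg` looks like after this change, plus a quick check that nothing downstream still expects the removed key. The `describe_backbone` helper is hypothetical and only illustrates which keys remain.

```python
# Sketch only: the keys mirror the diff above; describe_backbone is a hypothetical helper.
backbone_cfg = {
    'backbone': 'resnet50',
    'backbone_norm': 'FrozeBN',
    'pretrained': True,
    'pretrained_weight': 'imagenet1k_v1',
    'freeze_at': 0,
}

def describe_backbone(cfg: dict) -> str:
    # 'res5_dilation' was removed, so no lookup of it should remain downstream.
    assert 'res5_dilation' not in cfg
    return f"{cfg['backbone']} ({cfg['backbone_norm']}, freeze_at={cfg['freeze_at']})"

print(describe_backbone(backbone_cfg))  # resnet50 (FrozeBN, freeze_at=0)
```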

+ 3 - 8
models/detectors/rtdetr/basic_modules/backbone.py

@@ -46,7 +46,6 @@ class ResNet(nn.Module):
     """ResNet backbone with frozen BatchNorm."""
     def __init__(self,
                  name: str,
-                 res5_dilation: bool,
                  norm_type: str,
                  pretrained_weights: str = "imagenet1k_v1",
                  freeze_at: int = -1,
@@ -71,9 +70,7 @@ class ResNet(nn.Module):
         elif norm_type == 'FrozeBN':
             norm_layer = FrozenBatchNorm2d
         # Backbone
-        backbone = getattr(torchvision.models, name)(
-            replace_stride_with_dilation=[False, False, res5_dilation],
-            norm_layer=norm_layer, weights=pretrained_weights)
+        backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
         return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
@@ -98,7 +95,6 @@ class ResNet(nn.Module):
 def build_resnet(cfg, pretrained_weight=None):
     # ResNet series
     backbone = ResNet(cfg['backbone'],
-                      cfg['res5_dilation'],
                       cfg['backbone_norm'],
                       pretrained_weight,
                       cfg['freeze_at'],
@@ -120,7 +116,6 @@ if __name__ == '__main__':
     cfg = {
         'backbone':      'resnet18',
         'backbone_norm': 'BN',
-        'res5_dilation': False,
         'pretrained': True,
         'freeze_at': -1,
         'freeze_stem_only': True,
@@ -134,6 +129,6 @@ if __name__ == '__main__':
     for y in output:
         print(y.size())
 
-    for n, p in model.named_parameters():
-        print(n.split(".")[-1])
+    # for n, p in model.named_parameters():
+    #     print(n.split(".")[-1])
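
Dropping `replace_stride_with_dilation` means the backbone keeps torchvision's default strides, so `layer2/3/4` come out at strides 8/16/32. A standalone sketch (plain torchvision with random weights, not the repo's `ResNet` wrapper) that verifies those shapes:

```python
import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

# Build a plain resnet18 (random weights keeps the sketch offline) and tap
# layer2/3/4, mirroring the return_layers mapping used in the diff above.
backbone = torchvision.models.resnet18(weights=None)
body = IntermediateLayerGetter(backbone, return_layers={"layer2": "0", "layer3": "1", "layer4": "2"})

x = torch.randn(1, 3, 320, 320)
feats = body(x)
for name, f in feats.items():
    print(name, tuple(f.shape))
# Expected without dilation (strides 8/16/32):
# 0 (1, 128, 40, 40), 1 (1, 256, 20, 20), 2 (1, 512, 10, 10)
```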
 

+ 65 - 41
models/detectors/rtpdetr/basic_modules/backbone.py

@@ -2,36 +2,32 @@ import torch
 import torchvision
 from torch import nn
 from torchvision.models._utils import IntermediateLayerGetter
-from torchvision.models.resnet import (ResNet18_Weights,
-                                       ResNet34_Weights,
-                                       ResNet50_Weights,
-                                       ResNet101_Weights)
+
 try:
     from .basic import FrozenBatchNorm2d
 except:
     from basic  import FrozenBatchNorm2d
    
 
-# IN1K pretrained weights
+# IN1K MIM pretrained weights (from SparK: https://github.com/keyu-tian/SparK)
 pretrained_urls = {
     # ResNet series
-    'resnet18':  ResNet18_Weights,
-    'resnet34':  ResNet34_Weights,
-    'resnet50':  ResNet50_Weights,
-    'resnet101': ResNet101_Weights,
+    'resnet18':  None,
+    'resnet34':  None,
+    'resnet50':  "https://github.com/yjh0410/RT-ODLab/releases/download/backbone_weight/resnet50_in1k_spark_pretrained_timm_style.pth",
+    'resnet101': None,
     # ShuffleNet series
 }
 
 
 # ----------------- Model functions -----------------
 ## Build backbone network
-def build_backbone(cfg, pretrained):
+def build_backbone(cfg, pretrained=False):
     print('==============================')
     print('Backbone: {}'.format(cfg['backbone']))
     # ResNet
     if 'resnet' in cfg['backbone']:
-        pretrained_weight = cfg['pretrained_weight'] if pretrained else None
-        model, feats = build_resnet(cfg, pretrained_weight)
+        model, feats = build_resnet(cfg, pretrained)
     elif 'svnetv2' in cfg['backbone']:
         pretrained_weight = cfg['pretrained_weight'] if pretrained else None
         model, feats = build_scnetv2(cfg, pretrained_weight)
@@ -44,37 +40,62 @@ def build_backbone(cfg, pretrained):
 # ----------------- ResNet Backbone -----------------
 class ResNet(nn.Module):
     """ResNet backbone with frozen BatchNorm."""
-    def __init__(self, name: str, res5_dilation: bool, norm_type: str, pretrained_weights: str = "imagenet1k_v1"):
+    def __init__(self,
+                 name: str,
+                 norm_type: str,
+                 pretrained: bool = False,
+                 freeze_at: int = -1,
+                 freeze_stem_only: bool = False):
         super().__init__()
         # Pretrained
-        assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
-        if pretrained_weights is not None:
-            if name in ('resnet18', 'resnet34'):
-                pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
-            else:
-                if pretrained_weights == "imagenet1k_v1":
-                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
-                else:
-                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
-        else:
-            pretrained_weights = None
-        print('ImageNet pretrained weight: ', pretrained_weights)
         # Norm layer
         if norm_type == 'BN':
             norm_layer = nn.BatchNorm2d
         elif norm_type == 'FrozeBN':
             norm_layer = FrozenBatchNorm2d
         # Backbone
-        backbone = getattr(torchvision.models, name)(
-            replace_stride_with_dilation=[False, False, res5_dilation],
-            norm_layer=norm_layer, weights=pretrained_weights)
+        backbone = getattr(torchvision.models, name)(norm_layer=norm_layer,)
         return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
+        
+        # Load pretrained
+        if pretrained:
+            self.load_pretrained(name)
+
         # Freeze
-        for name, parameter in backbone.named_parameters():
-            if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
-                parameter.requires_grad_(False)
+        if freeze_at >= 0:
+            for name, parameter in backbone.named_parameters():
+                if freeze_stem_only:
+                    if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                        parameter.requires_grad_(False)
+                else:
+                    if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                        parameter.requires_grad_(False)
+
+    def load_pretrained(self, name):
+        url = pretrained_urls[name]
+        if url is not None:
+            print('Loading pretrained weight from : {}'.format(url))
+            # checkpoint state dict
+            checkpoint_state_dict = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            # model state dict
+            model_state_dict = self.body.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Unused key: ', k)
+            # load the weight
+            self.body.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No backbone pretrained for {}.'.format(name))
 
     def forward(self, x):
         xs = self.body(x)
@@ -84,9 +105,13 @@ class ResNet(nn.Module):
 
         return fmp_list
 
-def build_resnet(cfg, pretrained_weight=None):
+def build_resnet(cfg, pretrained=False):
     # ResNet series
-    backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight)
+    backbone = ResNet(cfg['backbone'],
+                      cfg['backbone_norm'],
+                      pretrained,
+                      cfg['freeze_at'],
+                      cfg['freeze_stem_only'])
 
     return backbone, backbone.feat_dims
 
@@ -102,20 +127,19 @@ def build_scnetv2(cfg, pretrained_weight=None):
 
 if __name__ == '__main__':
     cfg = {
-        'backbone':      'resnet18',
-        'backbone_norm': 'BN',
-        'res5_dilation': False,
+        'backbone': 'resnet50',
+        'backbone_norm': 'FrozeBN',
         'pretrained': True,
-        'pretrained_weight': 'imagenet1k_v1',
+        'freeze_at': 0,
+        'freeze_stem_only': False,
     }
     model, feat_dim = build_backbone(cfg, cfg['pretrained'])
+    model.eval()
     print(feat_dim)
 
-    x = torch.randn(2, 3, 320, 320)
+    x = torch.ones(2, 3, 320, 320)
     output = model(x)
     for y in output:
         print(y.size())
-
-    for n, p in model.named_parameters():
-        print(n.split(".")[-1])
+    print(output[-1])
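
The new `load_pretrained` path filters the downloaded checkpoint against the model's own state dict before loading, dropping keys that are missing or shape-mismatched. A generic, standalone sketch of that filtering pattern (no download, just two local modules standing in for checkpoint and model):

```python
import torch
import torch.nn as nn

def filter_state_dict(checkpoint_sd: dict, model_sd: dict) -> dict:
    """Keep only keys that exist in the model with matching shapes,
    mirroring the check in ResNet.load_pretrained above."""
    kept = {}
    for k, v in checkpoint_sd.items():
        if k in model_sd and tuple(model_sd[k].shape) == tuple(v.shape):
            kept[k] = v
        else:
            print('Dropping key:', k)
    return kept

# Toy example: the "checkpoint" model has a differently sized last layer.
model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
ckpt_model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 10))
ckpt_sd = ckpt_model.state_dict()

filtered = filter_state_dict(ckpt_sd, model.state_dict())
missing, unexpected = model.load_state_dict(filtered, strict=False)
print('loaded:', list(filtered.keys()))
print('still missing:', missing)
```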
 

+ 171 - 32
models/detectors/rtpdetr/basic_modules/basic.py

@@ -1,8 +1,80 @@
 import math
+import warnings
 import torch
 import torch.nn as nn
 
 
+def _trunc_normal_(tensor, mean, std, a, b):
+    """Copy from timm"""
+    def norm_cdf(x):
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                      "The distribution of values may be incorrect.",
+                      stacklevel=2)
+
+    l = norm_cdf((a - mean) / std)
+    u = norm_cdf((b - mean) / std)
+
+    tensor.uniform_(2 * l - 1, 2 * u - 1)
+    tensor.erfinv_()
+
+    tensor.mul_(std * math.sqrt(2.))
+    tensor.add_(mean)
+
+    tensor.clamp_(min=a, max=b)
+
+    return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    """Copy from timm"""
+    with torch.no_grad():
+        return _trunc_normal_(tensor, mean, std, a, b)
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+    
+    return torch.stack(b, dim=-1)
+
+def delta2bbox(proposals,
+               deltas,
+               max_shape=None,
+               wh_ratio_clip=16 / 1000,
+               clip_border=True,
+               add_ctr_clamp=False,
+               ctr_clamp=32):
+
+    dxy = deltas[..., :2]
+    dwh = deltas[..., 2:]
+
+    # Compute width/height of each roi
+    pxy = proposals[..., :2]
+    pwh = proposals[..., 2:]
+
+    dxy_wh = pwh * dxy
+    wh_ratio_clip = torch.as_tensor(wh_ratio_clip)
+    max_ratio = torch.abs(torch.log(wh_ratio_clip)).item()
+    
+    if add_ctr_clamp:
+        dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
+        dwh = torch.clamp(dwh, max=max_ratio)
+    else:
+        dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
+
+    gxy = pxy + dxy_wh
+    gwh = pwh * dwh.exp()
+    x1y1 = gxy - (gwh * 0.5)
+    x2y2 = gxy + (gwh * 0.5)
+    bboxes = torch.cat([x1y1, x2y2], dim=-1)
+    if clip_border and max_shape is not None:
+        bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1])
+        bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0])
+
+    return bboxes
+
+
 # ----------------- Custom NormLayer Ops -----------------
 class FrozenBatchNorm2d(torch.nn.Module):
     def __init__(self, n):
@@ -81,7 +153,6 @@ def get_norm(norm_type, dim):
     else:
         raise NotImplementedError
 
-
 class BasicConv(nn.Module):
     def __init__(self, 
                  in_dim,                   # in channels
@@ -147,9 +218,12 @@ class MLP(nn.Module):
         return x
 
 class FFN(nn.Module):
-    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
+    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu', pre_norm=False):
         super().__init__()
+        # ----------- Basic parameters -----------
+        self.pre_norm = pre_norm
         self.fpn_dim = round(d_model * mlp_ratio)
+        # ----------- Network parameters -----------
         self.linear1 = nn.Linear(d_model, self.fpn_dim)
         self.activation = get_activation(act_type)
         self.dropout2 = nn.Dropout(dropout)
@@ -158,40 +232,105 @@ class FFN(nn.Module):
         self.norm = nn.LayerNorm(d_model)
 
     def forward(self, src):
-        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
-        src = src + self.dropout3(src2)
-        src = self.norm(src)
+        if self.pre_norm:
+            src = self.norm(src)
+            src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+            src = src + self.dropout3(src2)
+        else:
+            src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+            src = src + self.dropout3(src2)
+            src = self.norm(src)
         
         return src
     
 
-# ----------------- Basic CNN Ops -----------------
-class FrozenBatchNorm2d(torch.nn.Module):
-    def __init__(self, n):
-        super(FrozenBatchNorm2d, self).__init__()
-        self.register_buffer("weight", torch.ones(n))
-        self.register_buffer("bias", torch.zeros(n))
-        self.register_buffer("running_mean", torch.zeros(n))
-        self.register_buffer("running_var", torch.ones(n))
+# ----------------- Attention Ops -----------------
+class GlobalCrossAttention(nn.Module):
+    def __init__(
+        self,
+        dim            :int   = 256,
+        num_heads      :int   = 8,
+        qkv_bias       :bool  = True,
+        qk_scale       :float = None,
+        attn_drop      :float = 0.0,
+        proj_drop      :float = 0.0,
+        rpe_hidden_dim :int   = 512,
+        feature_stride :int   = 16,
+    ):
+        super().__init__()
+        # --------- Basic parameters ---------
+        self.dim = dim
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+        self.feature_stride = feature_stride
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        num_batches_tracked_key = prefix + 'num_batches_tracked'
-        if num_batches_tracked_key in state_dict:
-            del state_dict[num_batches_tracked_key]
+        # --------- Network parameters ---------
+        self.cpb_mlp1 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads)
+        self.cpb_mlp2 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads)
+        self.q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.k = nn.Linear(dim, dim, bias=qkv_bias)
+        self.v = nn.Linear(dim, dim, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.softmax = nn.Softmax(dim=-1)
 
-        super(FrozenBatchNorm2d, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
+    def build_cpb_mlp(self, in_dim, hidden_dim, out_dim):
+        cpb_mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim, bias=True),
+                                nn.ReLU(inplace=True),
+                                nn.Linear(hidden_dim, out_dim, bias=False))
+        return cpb_mlp
 
-    def forward(self, x):
-        # move reshapes to the beginning
-        # to make it fuser-friendly
-        w = self.weight.reshape(1, -1, 1, 1)
-        b = self.bias.reshape(1, -1, 1, 1)
-        rv = self.running_var.reshape(1, -1, 1, 1)
-        rm = self.running_mean.reshape(1, -1, 1, 1)
-        eps = 1e-5
-        scale = w * (rv + eps).rsqrt()
-        bias = b - rm * scale
-        return x * scale + bias
+    def forward(
+        self,
+        query,
+        reference_points,
+        k_input_flatten,
+        v_input_flatten,
+        input_spatial_shapes,
+        input_padding_mask=None,
+    ):
+        assert input_spatial_shapes.size(0) == 1, 'This is designed for single-scale decoder.'
+        h, w = input_spatial_shapes[0]
+        stride = self.feature_stride
+
+        ref_pts = torch.cat([
+            reference_points[:, :, :, :2] - reference_points[:, :, :, 2:] / 2,
+            reference_points[:, :, :, :2] + reference_points[:, :, :, 2:] / 2,
+        ], dim=-1)  # B, nQ, 1, 4
+
+        pos_x = torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=w.device)[None, None, :, None] * stride  # 1, 1, w, 1
+        pos_y = torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=h.device)[None, None, :, None] * stride  # 1, 1, h, 1
+
+        delta_x = ref_pts[..., 0::2] - pos_x  # B, nQ, w, 2
+        delta_y = ref_pts[..., 1::2] - pos_y  # B, nQ, h, 2
+
+        rpe_x, rpe_y = self.cpb_mlp1(delta_x), self.cpb_mlp2(delta_y)  # B, nQ, w/h, nheads
+        rpe = (rpe_x[:, :, None] + rpe_y[:, :, :, None]).flatten(2, 3) # B, nQ, h, w, nheads ->  B, nQ, h*w, nheads
+        rpe = rpe.permute(0, 3, 1, 2)
+
+        B_, N, C = k_input_flatten.shape
+        k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+        v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+        B_, N, C = query.shape
+        q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+        q = q * self.scale
+
+        attn = q @ k.transpose(-2, -1)
+        attn += rpe
+        if input_padding_mask is not None:
+            attn += input_padding_mask[:, None, None] * -100
+
+        fmin, fmax = torch.finfo(attn.dtype).min, torch.finfo(attn.dtype).max
+        torch.clip_(attn, min=fmin, max=fmax)
+
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
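
As a worked example of the box reparameterization added above: with zero deltas, `delta2bbox` simply re-expresses a cxcywh proposal in xyxy coordinates. The snippet below is a minimal re-implementation of the same math for illustration, not an import from the repo.

```python
import torch

def delta2bbox_sketch(proposals, deltas, wh_ratio_clip=16 / 1000):
    """Minimal re-implementation of the delta2bbox math above:
    cxcywh proposals + (dx, dy, dw, dh) deltas -> xyxy boxes."""
    pxy, pwh = proposals[..., :2], proposals[..., 2:]
    dxy, dwh = deltas[..., :2], deltas[..., 2:]
    max_ratio = abs(torch.log(torch.tensor(wh_ratio_clip)).item())
    dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
    gxy = pxy + pwh * dxy          # shift the center by a box-relative offset
    gwh = pwh * dwh.exp()          # scale width/height multiplicatively
    return torch.cat([gxy - gwh * 0.5, gxy + gwh * 0.5], dim=-1)

# Zero deltas reproduce the proposal, just converted to xyxy:
proposal = torch.tensor([[100., 100., 40., 20.]])      # cx, cy, w, h
print(delta2bbox_sketch(proposal, torch.zeros(1, 4)))  # -> [[ 80.,  90., 120., 110.]]
```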

+ 254 - 95
models/detectors/rtpdetr/basic_modules/transformer.py

@@ -4,12 +4,14 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_
+import torch.utils.checkpoint as checkpoint
 
 try:
-    from .basic import get_activation, MLP, FFN
+    from .basic import FFN, GlobalCrossAttention
+    from .basic import trunc_normal_
 except:
-    from  basic import get_activation, MLP, FFN
+    from  basic import FFN, GlobalCrossAttention
+    from  basic import trunc_normal_
 
 
 def get_clones(module, N):
@@ -152,137 +154,294 @@ class TransformerEncoder(nn.Module):
 
         return src
 
-## Transformer Decoder layer
-class PlainTransformerDecoderLayer(nn.Module):
+## PlainDETR's Decoder layer
+class GlobalDecoderLayer(nn.Module):
     def __init__(self,
-                 d_model     :int   = 256,
-                 num_heads   :int   = 8,
-                 num_levels  :int   = 3,
-                 num_points  :int   = 4,
-                 mlp_ratio   :float = 4.0,
-                 dropout     :float = 0.1,
-                 act_type    :str   = "relu",
-                 ):
+                 d_model    :int   = 256,
+                 num_heads  :int   = 8,
+                 mlp_ratio  :float = 4.0,
+                 dropout    :float = 0.1,
+                 act_type   :str   = "relu",
+                 pre_norm   :bool  = False,
+                 rpe_hidden_dim :int = 512,
+                 feature_stride :int = 16,
+                 ) -> None:
         super().__init__()
-        # ----------- Basic parameters -----------
+        # ------------ Basic parameters ------------
         self.d_model = d_model
         self.num_heads = num_heads
-        self.num_levels = num_levels
-        self.num_points = num_points
+        self.rpe_hidden_dim = rpe_hidden_dim
         self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
         self.act_type = act_type
-        # ---------------- Network parameters ----------------
+        self.pre_norm = pre_norm
+
+        # ------------ Network parameters ------------
         ## Multi-head Self-Attn
-        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
         self.dropout1 = nn.Dropout(dropout)
         self.norm1 = nn.LayerNorm(d_model)
-        ## CrossAttention
-        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+
+        ## Box-reparam Global Cross-Attn
+        self.cross_attn = GlobalCrossAttention(d_model, num_heads, rpe_hidden_dim=rpe_hidden_dim, feature_stride=feature_stride)
         self.dropout2 = nn.Dropout(dropout)
         self.norm2 = nn.LayerNorm(d_model)
+
         ## FFN
-        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type, pre_norm)
 
-    def with_pos_embed(self, tensor, pos):
+    @staticmethod
+    def with_pos_embed(tensor, pos):
         return tensor if pos is None else tensor + pos
 
-    def forward(self,
-                tgt,
-                reference_points,
-                memory,
-                memory_spatial_shapes,
-                attn_mask=None,
-                memory_mask=None,
-                query_pos_embed=None):
-        # ---------------- MSHA for Object Query -----------------
-        q = k = self.with_pos_embed(tgt, query_pos_embed)
-        if attn_mask is not None:
-            attn_mask = torch.where(
-                attn_mask.bool(),
-                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=attn_mask.device),
-                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=attn_mask.device))
-        tgt2 = self.self_attn(q, k, value=tgt)[0]
-        tgt = tgt + self.dropout1(tgt2)
+    def forward_pre_norm(self,
+                         tgt,
+                         query_pos,
+                         reference_points,
+                         src,
+                         src_pos_embed,
+                         src_spatial_shapes,
+                         src_padding_mask=None,
+                         self_attn_mask=None,
+                         ):
+        # ----------- Multi-head self attention -----------
+        tgt1 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt1, query_pos)
+        tgt1 = self.self_attn(q.transpose(0, 1),        # [B, N, C] -> [N, B, C], batch_first = False
+                              k.transpose(0, 1),        # [B, N, C] -> [N, B, C], batch_first = False
+                              tgt1.transpose(0, 1),     # [B, N, C] -> [N, B, C], batch_first = False
+                              attn_mask=self_attn_mask,
+                              )[0].transpose(0, 1)      # [N, B, C] -> [B, N, C]
+        tgt = tgt + self.dropout1(tgt1)
+
+        # ----------- Global cross attention -----------
+        tgt1 = self.norm2(tgt)
+        tgt1 = self.cross_attn(self.with_pos_embed(tgt1, query_pos),
+                               reference_points,
+                               self.with_pos_embed(src, src_pos_embed),
+                               src,
+                               src_spatial_shapes,
+                               src_padding_mask,
+                               )
+        tgt = tgt + self.dropout2(tgt1)
+
+        # ----------- FeedForward Network -----------
+        tgt = self.ffn(tgt)
+
+        return tgt
+
+    def forward_post_norm(self,
+                          tgt,
+                          query_pos,
+                          reference_points,
+                          src,
+                          src_pos_embed,
+                          src_spatial_shapes,
+                          src_padding_mask=None,
+                          self_attn_mask=None,
+                          ):
+        # ----------- Multi-head self attention -----------
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt1 = self.self_attn(q.transpose(0, 1),        # [B, N, C] -> [N, B, C], batch_first = False
+                              k.transpose(0, 1),        # [B, N, C] -> [N, B, C], batch_first = False
+                              tgt.transpose(0, 1),     # [B, N, C] -> [N, B, C], batch_first = False
+                              attn_mask=self_attn_mask,
+                              )[0].transpose(0, 1)      # [N, B, C] -> [B, N, C]
+        tgt = tgt + self.dropout1(tgt1)
         tgt = self.norm1(tgt)
 
-        # ---------------- CMHA for Object Query and Image-feature -----------------
-        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
+        # ----------- Global cross attention -----------
+        tgt1 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                                reference_points,
-                               memory,
-                               memory_spatial_shapes,
-                               memory_mask)
-        tgt = tgt + self.dropout2(tgt2)
+                               self.with_pos_embed(src, src_pos_embed),
+                               src,
+                               src_spatial_shapes,
+                               src_padding_mask,
+                               )
+        tgt = tgt + self.dropout2(tgt1)
         tgt = self.norm2(tgt)
 
-        # ---------------- FeedForward Network -----------------
+        # ----------- FeedForward Network -----------
         tgt = self.ffn(tgt)
 
         return tgt
 
-## Transformer Decoder
-class PlainTransformerDecoder(nn.Module):
+    def forward(self,
+                tgt,
+                query_pos,
+                reference_points,
+                src,
+                src_pos_embed,
+                src_spatial_shapes,
+                src_padding_mask=None,
+                self_attn_mask=None,
+                ):
+        if self.pre_norm:
+            return self.forward_pre_norm(tgt, query_pos, reference_points, src, src_pos_embed, src_spatial_shapes,
+                                         src_padding_mask, self_attn_mask)
+        else:
+            return self.forward_post_norm(tgt, query_pos, reference_points, src, src_pos_embed, src_spatial_shapes,
+                                          src_padding_mask, self_attn_mask)
+
+## PlainDETR's Decoder
+class GlobalDecoder(nn.Module):
     def __init__(self,
-                 d_model        :int   = 256,
-                 num_heads      :int   = 8,
-                 num_layers     :int   = 1,
-                 num_levels     :int   = 3,
-                 num_points     :int   = 4,
-                 mlp_ratio      :float = 4.0,
-                 dropout        :float = 0.1,
-                 act_type       :str   = "relu",
+                 # Decoder layer params
+                 d_model    :int   = 256,
+                 num_heads  :int   = 8,
+                 mlp_ratio  :float = 4.0,
+                 dropout    :float = 0.1,
+                 act_type   :str   = "relu",
+                 pre_norm   :bool  = False,
+                 rpe_hidden_dim :int = 512,
+                 feature_stride :int = 16,
+                 num_layers     :int = 6,
+                 # Decoder params
                  return_intermediate :bool = False,
+                 use_checkpoint      :bool = False,
                  ):
         super().__init__()
-        # ----------- Basic parameters -----------
+        # ------------ Basic parameters ------------
         self.d_model = d_model
         self.num_heads = num_heads
-        self.num_layers = num_layers
+        self.rpe_hidden_dim = rpe_hidden_dim
         self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
         self.act_type = act_type
-        self.pos_embed = None
-        # ----------- Network parameters -----------
-        self.decoder_layers = get_clones(
-            TransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)
         self.num_layers = num_layers
         self.return_intermediate = return_intermediate
+        self.use_checkpoint = use_checkpoint
+
+        # ------------ Network parameters ------------
+        decoder_layer = GlobalDecoderLayer(
+            d_model, num_heads, mlp_ratio, dropout, act_type, pre_norm, rpe_hidden_dim, feature_stride,)
+        self.layers = get_clones(decoder_layer, num_layers)
+        self.bbox_embed = None
+        self.class_embed = None
+
+        if pre_norm:
+            self.final_layer_norm = nn.LayerNorm(d_model)
+        else:
+            self.final_layer_norm = None
+
+    def _reset_parameters(self):            
+        # stolen from Swin Transformer
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=0.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def inverse_sigmoid(self, x, eps=1e-5):
+        x = x.clamp(min=0, max=1)
+        x1 = x.clamp(min=eps)
+        x2 = (1 - x).clamp(min=eps)
+
+        return torch.log(x1 / x2)
+
+    def box_xyxy_to_cxcywh(self, x):
+        x0, y0, x1, y1 = x.unbind(-1)
+        b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+        
+        return torch.stack(b, dim=-1)
+
+    def delta2bbox(self, proposals,
+                   deltas,
+                   max_shape=None,
+                   wh_ratio_clip=16 / 1000,
+                   clip_border=True,
+                   add_ctr_clamp=False,
+                   ctr_clamp=32):
+
+        dxy = deltas[..., :2]
+        dwh = deltas[..., 2:]
+
+        # Compute width/height of each roi
+        pxy = proposals[..., :2]
+        pwh = proposals[..., 2:]
+
+        dxy_wh = pwh * dxy
+        wh_ratio_clip = torch.as_tensor(wh_ratio_clip)
+        max_ratio = torch.abs(torch.log(wh_ratio_clip)).item()
+        
+        if add_ctr_clamp:
+            dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
+            dwh = torch.clamp(dwh, max=max_ratio)
+        else:
+            dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
+
+        gxy = pxy + dxy_wh
+        gwh = pwh * dwh.exp()
+        x1y1 = gxy - (gwh * 0.5)
+        x2y2 = gxy + (gwh * 0.5)
+        bboxes = torch.cat([x1y1, x2y2], dim=-1)
+        if clip_border and max_shape is not None:
+            bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1])
+            bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0])
+
+        return bboxes
 
     def forward(self,
                 tgt,
-                ref_points_unact,
-                memory,
-                memory_spatial_shapes,
-                bbox_head,
-                score_head,
-                query_pos_head,
-                attn_mask=None,
-                memory_mask=None):
+                reference_points,
+                src,
+                src_pos_embed,
+                src_spatial_shapes,
+                query_pos=None,
+                src_padding_mask=None,
+                self_attn_mask=None,
+                max_shape=None,
+                ):
         output = tgt
-        dec_out_bboxes = []
-        dec_out_logits = []
-        ref_points_detach = F.sigmoid(ref_points_unact)
-        for i, layer in enumerate(self.decoder_layers):
-            ref_points_input = ref_points_detach.unsqueeze(2)
-            query_pos_embed = query_pos_head(ref_points_detach)
-
-            output = layer(output, ref_points_input, memory,
-                           memory_spatial_shapes, attn_mask,
-                           memory_mask, query_pos_embed)
-
-            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
-                ref_points_detach))
-
-            dec_out_logits.append(score_head[i](output))
-            if i == 0:
-                dec_out_bboxes.append(inter_ref_bbox)
+
+        intermediate = []
+        intermediate_reference_points = []
+        for lid, layer in enumerate(self.layers):
+            reference_points_input = reference_points[:, :, None]
+            if self.use_checkpoint:
+                output = checkpoint.checkpoint(
+                    layer,
+                    output,
+                    query_pos,
+                    reference_points_input,
+                    src,
+                    src_pos_embed,
+                    src_spatial_shapes,
+                    src_padding_mask,
+                    self_attn_mask,
+                )
             else:
-                dec_out_bboxes.append(
-                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
-                        ref_points)))
+                output = layer(
+                    output,
+                    query_pos,
+                    reference_points_input,
+                    src,
+                    src_pos_embed,
+                    src_spatial_shapes,
+                    src_padding_mask,
+                    self_attn_mask,
+                )
+
+            if self.final_layer_norm is not None:
+                output_after_norm = self.final_layer_norm(output)
+            else:
+                output_after_norm = output
+
+            # hack implementation for iterative bounding box refinement
+            if self.bbox_embed is not None:
+                tmp = self.bbox_embed[lid](output_after_norm)
+                new_reference_points = self.box_xyxy_to_cxcywh(
+                    self.delta2bbox(reference_points, tmp, max_shape)) 
+                reference_points = new_reference_points.detach()
 
-            ref_points = inter_ref_bbox
-            ref_points_detach = inter_ref_bbox.detach()
+            if self.return_intermediate:
+                intermediate.append(output_after_norm)
+                intermediate_reference_points.append(new_reference_points)
 
-        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+        if self.return_intermediate:
+            return torch.stack(intermediate), torch.stack(intermediate_reference_points)
 
+        return output_after_norm, reference_points
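
`GlobalDecoderLayer` (like the reworked `FFN` in basic.py) supports both residual orderings behind a `pre_norm` flag. A toy sketch contrasting the two orderings on a generic sublayer, under the assumption that only the norm/residual placement differs:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Toy sublayer + residual, switching between the pre-norm and
    post-norm orderings used by GlobalDecoderLayer/FFN above."""
    def __init__(self, dim=256, pre_norm=False):
        super().__init__()
        self.pre_norm = pre_norm
        self.sublayer = nn.Linear(dim, dim)   # stands in for attention/FFN
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        if self.pre_norm:                     # norm -> sublayer -> residual
            return x + self.dropout(self.sublayer(self.norm(x)))
        out = x + self.dropout(self.sublayer(x))
        return self.norm(out)                 # residual -> norm

x = torch.randn(2, 300, 256)
print(ResidualBlock(pre_norm=True)(x).shape, ResidualBlock(pre_norm=False)(x).shape)
```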

+ 265 - 63
models/detectors/rtpdetr/rtpdetr.py

@@ -1,10 +1,15 @@
+import math
 import torch
 import torch.nn as nn
 
 try:
+    from .basic_modules.basic import MLP
+    from .basic_modules.transformer import get_clones
     from .rtpdetr_encoder import build_image_encoder
     from .rtpdetr_decoder import build_transformer
 except:
+    from  basic_modules.basic import MLP
+    from  basic_modules.transformer import get_clones
     from  rtpdetr_encoder import build_image_encoder
     from  rtpdetr_decoder import build_transformer
 
@@ -21,6 +26,9 @@ class RT_PDETR(nn.Module):
                  ):
         super().__init__()
         # ----------- Basic setting -----------
+        self.num_queries_one2one = cfg['num_queries_one2one']
+        self.num_queries_one2many = cfg['num_queries_one2many']
+        self.num_queries = self.num_queries_one2one + self.num_queries_one2many
         self.num_classes = num_classes
         self.num_topk = topk
         self.conf_thresh = conf_thresh
@@ -30,12 +38,78 @@ class RT_PDETR(nn.Module):
         # ----------- Network setting -----------
         ## Image encoder
         self.image_encoder = build_image_encoder(cfg)
-        self.feat_dim = self.image_encoder.fpn_dims[-1]
 
-        ## Detect decoder
-        self.detect_decoder = build_transformer(cfg, self.feat_dim, num_classes, return_intermediate=self.training)
+        ## Transformer Decoder
+        self.transformer = build_transformer(cfg, return_intermediate=self.training)
+        self.query_embed = nn.Embedding(self.num_queries, cfg['hidden_dim'])
+
+        ## Detect Head
+        class_embed = nn.Linear(cfg['hidden_dim'], num_classes)
+        bbox_embed = MLP(cfg['hidden_dim'], cfg['hidden_dim'], 4, 3)
+
+        prior_prob = 0.01
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        class_embed.bias.data = torch.ones(num_classes) * bias_value
+        nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
+        nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
+
+        self.class_embed = get_clones(class_embed, cfg['de_num_layers'] + 1)
+        self.bbox_embed  = get_clones(bbox_embed, cfg['de_num_layers'] + 1)
+        nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
+
+        self.transformer.decoder.bbox_embed = self.bbox_embed
+        self.transformer.decoder.class_embed = self.class_embed
+
+    def pos2posembed(self, d_model, pos, temperature=10000):
+        scale = 2 * torch.pi
+        num_pos_feats = d_model // 2
+
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        # Position embedding for XY
+        x_embed = pos[..., 0] * scale
+        y_embed = pos[..., 1] * scale
+        pos_x = x_embed[..., None] / dim_t
+        pos_y = y_embed[..., None] / dim_t
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+        posemb = torch.cat((pos_y, pos_x), dim=-1)
+        
+        # Position embedding for WH
+        if pos.size(-1) == 4:
+            w_embed = pos[..., 2] * scale
+            h_embed = pos[..., 3] * scale
+            pos_w = w_embed[..., None] / dim_t
+            pos_h = h_embed[..., None] / dim_t
+            pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()), dim=-1).flatten(-2)
+            pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()), dim=-1).flatten(-2)
+            posemb = torch.cat((posemb, pos_w, pos_h), dim=-1)
+        
+        return posemb
+
+    def get_posembed(self, d_model, mask, temperature=10000):
+        not_mask = ~mask
+        # [B, H, W]
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+        y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + 1e-6)
+        x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + 1e-6)
+    
+        # [H, W] -> [B, H, W, 2]
+        pos = torch.stack([x_embed, y_embed], dim=-1)
+
+        # [B, H, W, C]
+        pos_embed = self.pos2posembed(d_model, pos, temperature)
+        pos_embed = pos_embed.permute(0, 3, 1, 2)
+        
+        return pos_embed
 
     def post_process(self, box_pred, cls_pred):
+        cls_pred = cls_pred[0]
+        box_pred = box_pred[0]
         if self.no_multi_labels:
             # [M,]
             scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
@@ -57,11 +131,10 @@ class RT_PDETR(nn.Module):
             topk_labels = labels[topk_idxs]
             topk_bboxes = box_pred[topk_idxs]
 
-            return topk_bboxes, topk_scores, topk_labels
         else:
             # Top-k select
-            cls_pred = cls_pred[0].flatten().sigmoid_()
-            box_pred = box_pred[0]
+            cls_pred = cls_pred.flatten().sigmoid_()
+            box_pred = box_pred
 
             # Keep top k top scoring indices only.
             num_topk = min(self.num_topk, box_pred.size(0))
@@ -84,77 +157,206 @@ class RT_PDETR(nn.Module):
 
         return topk_bboxes, topk_scores, topk_labels
     
-    def forward(self, x, targets=None):
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_coord_old, outputs_deltas):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [
+            {"pred_logits": a, "pred_boxes": b, "pred_boxes_old": c, "pred_deltas": d, }
+            for a, b, c, d in zip(outputs_class[:-1], outputs_coord[:-1], outputs_coord_old[:-1], outputs_deltas[:-1])
+        ]
+
+    def inference_single_image(self, x):
         # ----------- Image Encoder -----------
-        pyramid_feats = self.image_encoder(x)
+        src = self.image_encoder(x)
+
+        # ----------- Prepare inputs for Transformer -----------
+        mask = torch.zeros([src.shape[0], src.shape[2], src.shape[3]]).bool().to(src.device)
+        pos_embed = self.get_posembed(src.shape[1], mask)
+        self_attn_mask = None
+        query_embeds = self.query_embed.weight[:self.num_queries_one2one]
+
+        # ----------- Transformer -----------
+        (
+            hs,
+            init_reference,
+            inter_references,
+            _,
+            _,
+            _,
+            _,
+            max_shape
+        ) = self.transformer(src, mask, pos_embed, query_embeds, self_attn_mask)
+
+        # ----------- Process outputs -----------
+        outputs_classes_one2one = []
+        outputs_coords_one2one = []
+        outputs_deltas_one2one = []
 
-        # ----------- Transformer -----------
-        transformer_outputs = self.detect_decoder(pyramid_feats, targets)
+        for lid in range(hs.shape[0]):
+            if lid == 0:
+                reference = init_reference
+            else:
+                reference = inter_references[lid - 1]
+            outputs_class = self.class_embed[lid](hs[lid])
+            tmp = self.bbox_embed[lid](hs[lid])
+            outputs_coord = self.transformer.decoder.delta2bbox(reference, tmp, max_shape)  # xyxy
+
+            outputs_classes_one2one.append(outputs_class[:, :self.num_queries_one2one])
+            outputs_coords_one2one.append(outputs_coord[:, :self.num_queries_one2one])
+            outputs_deltas_one2one.append(tmp[:, :self.num_queries_one2one])
+
+        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
+        outputs_coords_one2one = torch.stack(outputs_coords_one2one)
+
+        # ------------ Post process ------------
+        cls_pred = outputs_classes_one2one[-1]
+        box_pred = outputs_coords_one2one[-1]
+        
+        # post-process
+        bboxes, scores, labels = self.post_process(box_pred, cls_pred)
+
+        outputs = {
+            "scores": scores.cpu().numpy(),
+            "labels": labels.cpu().numpy(),
+            "bboxes": bboxes.cpu().numpy(),
+        }
 
+        return outputs
+        
+    def forward(self, x):
+        if not self.training:
+            return self.inference_single_image(x)
+
+        # ----------- Image Encoder -----------
+        src = self.image_encoder(x)
+
+        # ----------- Prepare inputs for Transformer -----------
+        mask = torch.zeros([src.shape[0], src.shape[2], src.shape[3]]).bool().to(src.device)
+        pos_embed = self.get_posembed(src.shape[1], mask)
         if self.training:
-            return transformer_outputs
+            self_attn_mask = torch.zeros(
+                [self.num_queries, self.num_queries, ]).bool().to(src.device)
+            self_attn_mask[self.num_queries_one2one:, 0: self.num_queries_one2one, ] = True
+            self_attn_mask[0: self.num_queries_one2one, self.num_queries_one2one:, ] = True
+            query_embeds = self.query_embed.weight
         else:
-            pred_boxes, pred_logits = transformer_outputs[0], transformer_outputs[1]
-            box_preds = pred_boxes[-1]
-            cls_preds = pred_logits[-1]
-            
-            # post-process
-            bboxes, scores, labels = self.post_process(box_preds, cls_preds)
-
-            outputs = {
-                "scores": scores.cpu().numpy(),
-                "labels": labels.cpu().numpy(),
-                "bboxes": bboxes.cpu().numpy(),
-            }
-
-            return outputs
-        
+            self_attn_mask = None
+            query_embeds = self.query_embed.weight[:self.num_queries_one2one]
+
+        # ----------- Transformer -----------
+        (
+            hs,
+            init_reference,
+            inter_references,
+            enc_outputs_class,
+            enc_outputs_coord_unact,
+            enc_outputs_delta,
+            output_proposals,
+            max_shape
+        ) = self.transformer(src, mask, pos_embed, query_embeds, self_attn_mask)
+
+        # ----------- Process outputs -----------
+        outputs_classes_one2one = []
+        outputs_coords_one2one = []
+        outputs_classes_one2many = []
+        outputs_coords_one2many = []
+
+        outputs_coords_old_one2one = []
+        outputs_deltas_one2one = []
+        outputs_coords_old_one2many = []
+        outputs_deltas_one2many = []
+
+        for lid in range(hs.shape[0]):
+            if lid == 0:
+                reference = init_reference
+            else:
+                reference = inter_references[lid - 1]
+            outputs_class = self.class_embed[lid](hs[lid])
+            tmp = self.bbox_embed[lid](hs[lid])
+            outputs_coord = self.transformer.decoder.box_xyxy_to_cxcywh(
+                self.transformer.decoder.delta2bbox(reference, tmp, max_shape))
+
+            outputs_classes_one2one.append(outputs_class[:, 0: self.num_queries_one2one])
+            outputs_classes_one2many.append(outputs_class[:, self.num_queries_one2one:])
+
+            outputs_coords_one2one.append(outputs_coord[:, 0: self.num_queries_one2one])
+            outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one:])
+
+            outputs_coords_old_one2one.append(reference[:, :self.num_queries_one2one])
+            outputs_coords_old_one2many.append(reference[:, self.num_queries_one2one:])
+            outputs_deltas_one2one.append(tmp[:, :self.num_queries_one2one])
+            outputs_deltas_one2many.append(tmp[:, self.num_queries_one2one:])
+
+        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
+        outputs_coords_one2one = torch.stack(outputs_coords_one2one)
+
+        outputs_classes_one2many = torch.stack(outputs_classes_one2many)
+        outputs_coords_one2many = torch.stack(outputs_coords_one2many)
+
+        out = {
+            "pred_logits": outputs_classes_one2one[-1],
+            "pred_boxes": outputs_coords_one2one[-1],
+            "pred_logits_one2many": outputs_classes_one2many[-1],
+            "pred_boxes_one2many": outputs_coords_one2many[-1],
+
+            "pred_boxes_old": outputs_coords_old_one2one[-1],
+            "pred_deltas": outputs_deltas_one2one[-1],
+            "pred_boxes_old_one2many": outputs_coords_old_one2many[-1],
+            "pred_deltas_one2many": outputs_deltas_one2many[-1],
+        }
+
+        out["aux_outputs"] = self._set_aux_loss(
+            outputs_classes_one2one, outputs_coords_one2one, outputs_coords_old_one2one, outputs_deltas_one2one
+        )
+        out["aux_outputs_one2many"] = self._set_aux_loss(
+            outputs_classes_one2many, outputs_coords_one2many, outputs_coords_old_one2many, outputs_deltas_one2many
+        )
+
+        out["enc_outputs"] = {
+            "pred_logits": enc_outputs_class,
+            "pred_boxes": enc_outputs_coord_unact,
+            "pred_boxes_old": output_proposals,
+            "pred_deltas": enc_outputs_delta,
+        }
+
+        return out
+                
 
 if __name__ == '__main__':
     import time
     from thop import profile
-    from loss import build_criterion
+    # from loss import build_criterion
 
     # Model config
     cfg = {
         'width': 1.0,
         'depth': 1.0,
-        'out_stride': [8, 16, 32],
+        'max_stride': 32,
+        'out_stride': 16,
         # Image Encoder - Backbone
-        'backbone': 'resnet18',
-        'backbone_norm': 'BN',
-        'res5_dilation': False,
+        'backbone': 'resnet50',
+        'backbone_norm': 'FrozeBN',
         'pretrained': True,
-        'pretrained_weight': 'imagenet1k_v1',
-        # Image Encoder - FPN
-        'fpn': 'hybrid_encoder',
-        'fpn_act': 'silu',
-        'fpn_norm': 'BN',
-        'fpn_depthwise': False,
+        'freeze_at': 0,
+        'freeze_stem_only': False,
         'hidden_dim': 256,
-        'en_num_heads': 8,
-        'en_num_layers': 1,
-        'en_mlp_ratio': 4.0,
-        'en_dropout': 0.1,
-        'pe_temperature': 10000.,
-        'en_act': 'gelu',
         # Transformer Decoder
-        'transformer': 'rtdetr_transformer',
+        'transformer': 'plain_detr_transformer',
         'hidden_dim': 256,
         'de_num_heads': 8,
         'de_num_layers': 6,
         'de_mlp_ratio': 4.0,
-        'de_dropout': 0.0,
+        'de_dropout': 0.1,
         'de_act': 'gelu',
-        'de_num_points': 4,
-        'num_queries': 300,
-        'learnt_init_query': False,
-        'pe_temperature': 10000.,
-        'dn_num_denoising': 100,
-        'dn_label_noise_ratio': 0.5,
-        'dn_box_noise_scale': 1,
-        # Head
-        'det_head': 'dino_head',
+        'de_pre_norm': True,
+        'rpe_hidden_dim': 512,
+        'use_checkpoint': False,
+        'proposal_feature_levels': 3,
+        'proposal_tgt_strides': [8, 16, 32],
+        'num_queries_one2one': 300,
+        'num_queries_one2many': 300,
         # Matcher
         'matcher_hpy': {'cost_class': 2.0,
                         'cost_bbox': 5.0,
@@ -175,22 +377,22 @@ if __name__ == '__main__':
     }] * bs
 
     # Create model
-    model = RT_DETR(cfg, num_classes=80)
+    model = RT_PDETR(cfg, num_classes=80)
     model.train()
 
-    # Create criterion
-    criterion = build_criterion(cfg, num_classes=80)
-
     # Model inference
     t0 = time.time()
-    outputs = model(image, targets)
+    outputs = model(image)
     t1 = time.time()
     print('Infer time: ', t1 - t0)
 
-    # Compute loss
-    loss = criterion(*outputs, targets)
-    for k in loss.keys():
-        print("{} : {}".format(k, loss[k].item()))
+    # # Create criterion
+    # criterion = build_criterion(cfg, num_classes=80)
+
+    # # Compute loss
+    # loss = criterion(*outputs, targets)
+    # for k in loss.keys():
+    #     print("{} : {}".format(k, loss[k].item()))
 
     print('==============================')
     model.eval()
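
During training, the self-attention mask built in `RT_PDETR.forward` blocks attention between the one-to-one and one-to-many query groups while leaving attention within each group unrestricted. A minimal sketch of that mask construction, with tiny query counts so the block structure is easy to print:

```python
import torch

# Sketch of the hybrid-matching attention mask built in RT_PDETR.forward above,
# shrunk to a few queries for readability.
num_one2one, num_one2many = 3, 4
num_queries = num_one2one + num_one2many

self_attn_mask = torch.zeros(num_queries, num_queries, dtype=torch.bool)
# True = masked out: the two query groups cannot attend to each other,
# but each group still attends freely within itself.
self_attn_mask[num_one2one:, :num_one2one] = True
self_attn_mask[:num_one2one, num_one2one:] = True
print(self_attn_mask.int())
```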

+ 321 - 282
models/detectors/rtpdetr/rtpdetr_decoder.py

@@ -2,295 +2,318 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_, uniform_, normal_
-from typing import List
 
 try:
-    from .basic_modules.basic import BasicConv, MLP
-    from .basic_modules.transformer import PlainTransformerDecoder
+    from .basic_modules.basic import LayerNorm2D
+    from .basic_modules.transformer import GlobalDecoder
 except:
-    from  basic_modules.basic import BasicConv, MLP
-    from  basic_modules.transformer import PlainTransformerDecoder
+    from  basic_modules.basic import LayerNorm2D
+    from  basic_modules.transformer import GlobalDecoder
 
-
-def build_transformer(cfg, in_dims, num_classes, return_intermediate=False):
+def build_transformer(cfg, return_intermediate=False):
     if cfg['transformer'] == 'plain_detr_transformer':
-        return PlainDETRTransformer(in_dims             = in_dims,
-                                 hidden_dim          = cfg['hidden_dim'],
-                                 strides             = cfg['out_stride'],
-                                 num_classes         = num_classes,
-                                 num_queries         = cfg['num_queries'],
-                                 pos_embed_type      = 'sine',
-                                 num_heads           = cfg['de_num_heads'],
-                                 num_layers          = cfg['de_num_layers'],
-                                 num_levels          = len(cfg['out_stride']),
-                                 num_points          = cfg['de_num_points'],
-                                 mlp_ratio           = cfg['de_mlp_ratio'],
-                                 dropout             = cfg['de_dropout'],
-                                 act_type            = cfg['de_act'],
-                                 return_intermediate = return_intermediate,
-                                 num_denoising       = cfg['dn_num_denoising'],
-                                 label_noise_ratio   = cfg['dn_label_noise_ratio'],
-                                 box_noise_scale     = cfg['dn_box_noise_scale'],
-                                 learnt_init_query   = cfg['learnt_init_query'],
-                                 )
+        return PlainDETRTransformer(d_model             = cfg['hidden_dim'],
+                                    num_heads           = cfg['de_num_heads'],
+                                    mlp_ratio           = cfg['de_mlp_ratio'],
+                                    dropout             = cfg['de_dropout'],
+                                    act_type            = cfg['de_act'],
+                                    pre_norm            = cfg['de_pre_norm'],
+                                    rpe_hidden_dim      = cfg['rpe_hidden_dim'],
+                                    feature_stride      = cfg['out_stride'],
+                                    num_layers          = cfg['de_num_layers'],
+                                    return_intermediate = return_intermediate,
+                                    use_checkpoint      = cfg['use_checkpoint'],
+                                    num_queries_one2one = cfg['num_queries_one2one'],
+                                    num_queries_one2many = cfg['num_queries_one2many'],
+                                    proposal_feature_levels = cfg['proposal_feature_levels'],
+                                    proposal_in_stride      = cfg['out_stride'],
+                                    proposal_tgt_strides    = cfg['proposal_tgt_strides'],
+                                    )
 
 
 # ----------------- Decoder for Detection task -----------------
-## RTDETR's Transformer for Detection task
+## PlainDETR's Transformer for Detection task
 class PlainDETRTransformer(nn.Module):
     def __init__(self,
-                 # basic parameters
-                 in_dims        :List = [256, 512, 1024],
-                 hidden_dim     :int  = 256,
-                 strides        :List = [8, 16, 32],
-                 num_classes    :int  = 80,
-                 num_queries    :int  = 300,
-                 pos_embed_type :str  = 'sine',
-                 # transformer parameters
+                 # Decoder layer params
+                 d_model        :int   = 256,
                  num_heads      :int   = 8,
-                 num_layers     :int   = 1,
-                 num_levels     :int   = 3,
-                 num_points     :int   = 4,
                  mlp_ratio      :float = 4.0,
                  dropout        :float = 0.1,
                  act_type       :str   = "relu",
+                 pre_norm       :bool  = False,
+                 rpe_hidden_dim :int   = 512,
+                 feature_stride :int   = 16,
+                 num_layers     :int   = 6,
+                 # Decoder params
                  return_intermediate :bool = False,
-                 # Denoising parameters
-                 num_denoising       :int  = 100,
-                 label_noise_ratio   :float = 0.5,
-                 box_noise_scale     :float = 1.0,
-                 learnt_init_query   :bool  = True,
+                 use_checkpoint      :bool = False,
+                 num_queries_one2one :int = 300,
+                 num_queries_one2many :int = 1500,
+                 proposal_feature_levels :int = 3,
+                 proposal_in_stride      :int = 16,
+                 proposal_tgt_strides    :int = [8, 16, 32],
                  ):
         super().__init__()
-        # --------------- Basic setting ---------------
-        ## Basic parameters
-        self.in_dims = in_dims
-        self.strides = strides
-        self.num_queries = num_queries
-        self.pos_embed_type = pos_embed_type
-        self.num_classes = num_classes
-        self.eps = 1e-2
-        ## Transformer parameters
-        self.num_heads  = num_heads
+        # ------------ Basic setting ------------
+        ## Model
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.rpe_hidden_dim = rpe_hidden_dim
+        self.mlp_ratio = mlp_ratio
+        self.act_type = act_type
         self.num_layers = num_layers
-        self.num_levels = num_levels
-        self.num_points = num_points
-        self.mlp_ratio  = mlp_ratio
-        self.dropout    = dropout
-        self.act_type   = act_type
         self.return_intermediate = return_intermediate
-        ## Denoising parameters
-        self.num_denoising = num_denoising
-        self.label_noise_ratio = label_noise_ratio
-        self.box_noise_scale = box_noise_scale
-        self.learnt_init_query = learnt_init_query
+        ## Trick
+        self.use_checkpoint = use_checkpoint
+        self.num_queries_one2one = num_queries_one2one
+        self.num_queries_one2many = num_queries_one2many
+        self.proposal_feature_levels = proposal_feature_levels
+        self.proposal_tgt_strides = proposal_tgt_strides
+        self.proposal_in_stride = proposal_in_stride
+        self.proposal_min_size = 50
 
         # --------------- Network setting ---------------
-        ## Input proj layers
-        self.input_proj_layers = nn.ModuleList(
-            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
-            for i in range(num_levels)
-        )
-
-        ## Deformable transformer decoder
-        self.decoder = PlainTransformerDecoder(
-                                    d_model    = hidden_dim,
-                                    num_heads  = num_heads,
-                                    num_layers = num_layers,
-                                    num_levels = num_levels,
-                                    num_points = num_points,
-                                    mlp_ratio  = mlp_ratio,
-                                    dropout    = dropout,
-                                    act_type   = act_type,
-                                    return_intermediate = return_intermediate
-                                    )
+        ## Global Decoder
+        self.decoder = GlobalDecoder(d_model, num_heads, mlp_ratio, dropout, act_type, pre_norm,
+                                     rpe_hidden_dim, feature_stride, num_layers, return_intermediate,
+                                     use_checkpoint,)
         
-        ## Detection head for Encoder
-        self.enc_output = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim),
-            nn.LayerNorm(hidden_dim)
-            )
-        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
-        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
-
-        ## Detection head for Decoder
-        self.dec_class_head = nn.ModuleList([
-            nn.Linear(hidden_dim, num_classes)
-            for _ in range(num_layers)
-        ])
-        self.dec_bbox_head = nn.ModuleList([
-            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
-            for _ in range(num_layers)
-        ])
-
-        ## Object query
-        if learnt_init_query:
-            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
-        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
-
-        ## Denoising part
-        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+        ## Two stage
+        self.enc_output = nn.Linear(d_model, d_model)
+        self.enc_output_norm = nn.LayerNorm(d_model)
+        self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
+        self.pos_trans_norm = nn.LayerNorm(d_model * 2)
+
+        ## Expand layers
+        if proposal_feature_levels > 1:
+            assert len(proposal_tgt_strides) == proposal_feature_levels
+
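+            # One projection per target stride: identity at the input stride, stride-2 convs for coarser strides, stride-2 transposed convs for finer ones.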
+            self.enc_output_proj = nn.ModuleList([])
+            for stride in proposal_tgt_strides:
+                if stride == proposal_in_stride:
+                    self.enc_output_proj.append(nn.Identity())
+                elif stride > proposal_in_stride:
+                    scale = int(math.log2(stride / proposal_in_stride))
+                    layers = []
+                    for _ in range(scale - 1):
+                        layers += [
+                            nn.Conv2d(d_model, d_model, kernel_size=2, stride=2),
+                            LayerNorm2D(d_model),
+                            nn.GELU()
+                        ]
+                    layers.append(nn.Conv2d(d_model, d_model, kernel_size=2, stride=2))
+                    self.enc_output_proj.append(nn.Sequential(*layers))
+                else:
+                    scale = int(math.log2(proposal_in_stride / stride))
+                    layers = []
+                    for _ in range(scale - 1):
+                        layers += [
+                            nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2),
+                            LayerNorm2D(d_model),
+                            nn.GELU()
+                        ]
+                    layers.append(nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2))
+                    self.enc_output_proj.append(nn.Sequential(*layers))
 
         self._reset_parameters()
 
     def _reset_parameters(self):
-        def linear_init_(module):
-            bound = 1 / math.sqrt(module.weight.shape[0])
-            uniform_(module.weight, -bound, bound)
-            if hasattr(module, "bias") and module.bias is not None:
-                uniform_(module.bias, -bound, bound)
-
-        # class and bbox head init
-        prior_prob = 0.01
-        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
-        linear_init_(self.enc_class_head)
-        constant_(self.enc_class_head.bias, cls_bias_init)
-        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
-        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
-        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
-            linear_init_(cls_)
-            constant_(cls_.bias, cls_bias_init)
-            constant_(reg_.layers[-1].weight, 0.)
-            constant_(reg_.layers[-1].bias, 0.)
-
-        linear_init_(self.enc_output[0])
-        xavier_uniform_(self.enc_output[0].weight)
-        if self.learnt_init_query:
-            xavier_uniform_(self.tgt_embed.weight)
-        xavier_uniform_(self.query_pos_head.layers[0].weight)
-        xavier_uniform_(self.query_pos_head.layers[1].weight)
-        for l in self.input_proj_layers:
-            xavier_uniform_(l.conv.weight)
-        normal_(self.denoising_class_embed.weight)
-
-    def generate_anchors(self, spatial_shapes, grid_size=0.05):
-        anchors = []
-        for lvl, (h, w) in enumerate(spatial_shapes):
-            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
-            # [H, W, 2]
-            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
-
-            valid_WH = torch.as_tensor([w, h]).float()
-            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
-            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
-            # [H, W, 4] -> [1, N, 4], N=HxW
-            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
-        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
-        anchors = torch.cat(anchors, dim=1)
-        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
-        anchors = torch.log(anchors / (1 - anchors))
-        # Equal to operation: anchors = torch.masked_fill(anchors, ~valid_mask, torch.as_tensor(float("inf")))
-        anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+        if hasattr(self.decoder, '_reset_parameters'):
+            print('decoder re-init')
+            self.decoder._reset_parameters()
+
+    def get_proposal_pos_embed(self, proposals):
+        num_pos_feats = self.d_model // 2
+        temperature = 10000
+        scale = 2 * torch.pi
+
+        dim_t = torch.arange(
+            num_pos_feats, dtype=torch.float32, device=proposals.device
+        )
+        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        # N, L, 4
+        proposals = proposals * scale
+        # N, L, 4, 128
+        pos = proposals[:, :, :, None] / dim_t
+        # N, L, 4, 64, 2
+        pos = torch.stack(
+            (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4
+        ).flatten(2)
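+        # after flatten: N, L, 4 * num_pos_feats = N, L, 2 * d_model (the input dim of self.pos_trans)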
+
+        return pos
+
+    def get_valid_ratio(self, mask):
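+        # Fraction of valid (non-padded) width / height in the mask, returned as [B, 2] = (w_ratio, h_ratio).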
+        _, H, W = mask.shape
+        valid_H = torch.sum(~mask[:, :, 0], 1)
+        valid_W = torch.sum(~mask[:, 0, :], 1)
+        valid_ratio_h = valid_H.float() / H
+        valid_ratio_w = valid_W.float() / W
+        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+
+        return valid_ratio
+
+    def expand_encoder_output(self, memory, memory_padding_mask, spatial_shapes):
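+        # Expand the single-level encoder memory (at proposal_in_stride) into a small pyramid at proposal_tgt_strides via the enc_output_proj layers.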
+        assert spatial_shapes.size(0) == 1, f'Got encoder output with spatial shapes {spatial_shapes}; expected a single feature level to expand'
+
+        bs, _, c = memory.shape
+        h, w = spatial_shapes[0]
+
+        _out_memory = memory.view(bs, h, w, c).permute(0, 3, 1, 2)
+        _out_memory_padding_mask = memory_padding_mask.view(bs, h, w)
+
+        out_memory, out_memory_padding_mask, out_spatial_shapes = [], [], []
+        for i in range(self.proposal_feature_levels):
+            mem = self.enc_output_proj[i](_out_memory)
+            mask = F.interpolate(
+                _out_memory_padding_mask[None].float(), size=mem.shape[-2:]
+            ).to(torch.bool)
+
+            out_memory.append(mem)
+            out_memory_padding_mask.append(mask.squeeze(0))
+            out_spatial_shapes.append(mem.shape[-2:])
+
+        out_memory = torch.cat([mem.flatten(2).transpose(1, 2) for mem in out_memory], dim=1)
+        out_memory_padding_mask = torch.cat([mask.flatten(1) for mask in out_memory_padding_mask], dim=1)
+        out_spatial_shapes = torch.as_tensor(out_spatial_shapes, dtype=torch.long, device=out_memory.device)
         
-        return anchors, valid_mask
+        return out_memory, out_memory_padding_mask, out_spatial_shapes
+
+    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
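+        # Build absolute (cx, cy, w, h) proposals in image pixels for every feature location; padded or out-of-range ones are pushed off-image below.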
+        if self.proposal_feature_levels > 1:
+            memory, memory_padding_mask, spatial_shapes = self.expand_encoder_output(
+                memory, memory_padding_mask, spatial_shapes
+            )
+        N_, S_, C_ = memory.shape
+        # base_scale = 4.0
+        proposals = []
+        _cur = 0
+        for lvl, (H_, W_) in enumerate(spatial_shapes):
+            stride = self.proposal_tgt_strides[lvl]
+
+            grid_y, grid_x = torch.meshgrid(
+                torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+                torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
+            )
+            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) * stride
+            wh = torch.ones_like(grid) * self.proposal_min_size * (2.0 ** lvl)
+            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+            proposals.append(proposal)
+            _cur += H_ * W_
+        output_proposals = torch.cat(proposals, 1)
+
+        H_, W_ = spatial_shapes[0]
+        stride = self.proposal_tgt_strides[0]
+        mask_flatten_ = memory_padding_mask[:, :H_*W_].view(N_, H_, W_, 1)
+        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1, keepdim=True) * stride
+        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1, keepdim=True) * stride
+        img_size = torch.cat([valid_W, valid_H, valid_W, valid_H], dim=-1)
+        img_size = img_size.unsqueeze(1) # [BS, 1, 4]
+
+        output_proposals_valid = (
+            (output_proposals > 0.01 * img_size) & (output_proposals < 0.99 * img_size)
+        ).all(-1, keepdim=True)
+        output_proposals = output_proposals.masked_fill(
+            memory_padding_mask.unsqueeze(-1).repeat(1, 1, 1),
+            max(H_, W_) * stride,
+        )
+        output_proposals = output_proposals.masked_fill(
+            ~output_proposals_valid,
+            max(H_, W_) * stride,
+        )
+
+        output_memory = memory
+        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+        output_memory = self.enc_output_norm(self.enc_output(output_memory))
+
+        max_shape = (valid_H[:, None, :], valid_W[:, None, :])
+        return output_memory, output_proposals, max_shape
     
-    def get_encoder_input(self, feats):
-        # get projection features
-        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
-
-        # get encoder inputs
-        feat_flatten = []
-        spatial_shapes = []
-        level_start_index = [0, ]
-        for i, feat in enumerate(proj_feats):
-            _, _, h, w = feat.shape
-            spatial_shapes.append([h, w])
-            # [l], start index of each level
-            level_start_index.append(h * w + level_start_index[-1])
-            # [B, C, H, W] -> [B, N, C], N=HxW
-            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
-
-        # [B, N, C], N = N_0 + N_1 + ...
-        feat_flatten = torch.cat(feat_flatten, dim=1)
-        level_start_index.pop()
-
-        return (feat_flatten, spatial_shapes, level_start_index)
-
-    def get_decoder_input(self,
-                          memory,
-                          spatial_shapes,
-                          denoising_class=None,
-                          denoising_bbox_unact=None):
-        bs, _, _ = memory.shape
-        # Prepare input for decoder
-        anchors, valid_mask = self.generate_anchors(spatial_shapes)
-        anchors = anchors.to(memory.device)
-        valid_mask = valid_mask.to(memory.device)
+    def get_reference_points(self, memory, mask_flatten, spatial_shapes):
+        output_memory, output_proposals, max_shape = self.gen_encoder_output_proposals(
+            memory, mask_flatten, spatial_shapes
+        )
+
+        # hack implementation for two-stage Deformable DETR
+        enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
+        enc_outputs_delta = self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+        enc_outputs_coord_unact = self.decoder.box_xyxy_to_cxcywh(self.decoder.delta2bbox(
+            output_proposals,
+            enc_outputs_delta,
+            max_shape
+        ))
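+        # The extra (index num_layers) class / bbox heads score and refine the encoder proposals; the top-k boxes (ranked by the first class logit) become the decoder reference points.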
+
+        topk = self.two_stage_num_proposals
+        topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+        topk_coords_unact = torch.gather(
+            enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+        )
+        topk_coords_unact = topk_coords_unact.detach()
+        reference_points = topk_coords_unact
         
-        # Process encoder's output
-        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
-        output_memory = self.enc_output(memory)
-
-        # Head for encoder's output : [bs, num_quries, c]
-        enc_outputs_class = self.enc_class_head(output_memory)
-        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
-
-        # Topk proposals from encoder's output
-        topk = self.num_queries
-        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
-        enc_topk_logits = torch.gather(
-            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
-        reference_points_unact = torch.gather(
-            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))    # [bs, num_queries, 4]
-        enc_topk_bboxes = F.sigmoid(reference_points_unact)
-
-        if denoising_bbox_unact is not None:
-            reference_points_unact = torch.cat(
-                [denoising_bbox_unact, reference_points_unact], 1)
-        if self.training:
-            reference_points_unact = reference_points_unact.detach()
+        return (reference_points, max_shape, enc_outputs_class,
+                enc_outputs_coord_unact, enc_outputs_delta, output_proposals)
 
-        # Extract region features
-        if self.learnt_init_query:
-            # [num_queries, c] -> [b, num_queries, c]
-            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
-        else:
-            # [num_queries, c] -> [b, num_queries, c]
-            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
-            if self.training:
-                target = target.detach()
-        if denoising_class is not None:
-            target = torch.cat([denoising_class, target], dim=1)
-
-        return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
-    
-    def forward(self, feats, targets=None):
-        # input projection and embedding
-        memory, spatial_shapes, _ = self.get_encoder_input(feats)
+    def forward(self, src, mask, pos_embed, query_embed=None, self_attn_mask=None):
+        # Prepare input for encoder
+        bs, c, h, w = src.shape
+        src_flatten = src.flatten(2).transpose(1, 2)
+        mask_flatten = mask.flatten(1)
+        pos_embed_flatten = pos_embed.flatten(2).transpose(1, 2)
+        spatial_shapes = torch.as_tensor([(h, w)], dtype=torch.long, device=src_flatten.device)
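+        # Single-level input: [B, C, H, W] -> [B, H*W, C]; spatial_shapes holds the single (h, w) pair.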
 
-        # prepare denoising training
+        # Prepare input for decoder
+        memory = src_flatten
+        bs, _, c = memory.shape
+
+        # Two stage trick
         if self.training:
-            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
-                get_contrastive_denoising_training_group(targets,
-                                                         self.num_classes,
-                                                         self.num_queries,
-                                                         self.denoising_class_embed.weight,
-                                                         self.num_denoising,
-                                                         self.label_noise_ratio,
-                                                         self.box_noise_scale)
+            self.two_stage_num_proposals = self.num_queries_one2one + self.num_queries_one2many
         else:
-            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
-
-        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
-            self.get_decoder_input(
-            memory, spatial_shapes, denoising_class, denoising_bbox_unact)
-
-        # decoder
-        out_bboxes, out_logits = self.decoder(target,
-                                              init_ref_points_unact,
-                                              memory,
-                                              spatial_shapes,
-                                              self.dec_bbox_head,
-                                              self.dec_class_head,
-                                              self.query_pos_head,
-                                              attn_mask)
-        
-        return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
+            self.two_stage_num_proposals = self.num_queries_one2one
+        (reference_points, max_shape, enc_outputs_class,
+        enc_outputs_coord_unact, enc_outputs_delta, output_proposals) \
+            = self.get_reference_points(memory, mask_flatten, spatial_shapes)
+        init_reference_out = reference_points
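+        # NOTE: the zeros tensor below is only a placeholder and is immediately overwritten by the sine-embedded proposal coordinates.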
+        pos_trans_out = torch.zeros((bs, self.two_stage_num_proposals, 2*c), device=init_reference_out.device)
+        pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(reference_points)))
+
+        # Mixed selection trick
+        tgt = query_embed.unsqueeze(0).expand(bs, -1, -1)
+        query_embed, _ = torch.split(pos_trans_out, c, dim=2)
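+        # Mixed query selection: content queries (tgt) come from the learned query_embed, while positional queries are the sine embeddings of the top-k encoder proposals.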
+
+        # Decoder
+        hs, inter_references = self.decoder(tgt,
+                                            reference_points,
+                                            memory,
+                                            pos_embed_flatten,
+                                            spatial_shapes,
+                                            query_embed,
+                                            mask_flatten,
+                                            self_attn_mask,
+                                            max_shape
+                                            )
+        inter_references_out = inter_references
+
+        return (hs,
+                init_reference_out,
+                inter_references_out,
+                enc_outputs_class,
+                enc_outputs_coord_unact,
+                enc_outputs_delta,
+                output_proposals,
+                max_shape
+                )
 
 
 # ----------------- Decoder for Segmentation task -----------------
-## RTDETR's Transformer for Segmentation task
+## PlainDETR's Transformer for Segmentation task
 class SegTransformerDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()
@@ -301,7 +324,7 @@ class SegTransformerDecoder(nn.Module):
 
 
 # ----------------- Decoder for Pose estimation task -----------------
-## RTDETR's Transformer for Pose estimation task
+## PlainDETR's Transformer for Pose estimation task
 class PosTransformerDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()
@@ -314,50 +337,66 @@ class PosTransformerDecoder(nn.Module):
 if __name__ == '__main__':
     import time
     from thop import profile
+    from basic_modules.basic import MLP
+    from basic_modules.transformer import get_clones
+
     cfg = {
-        'out_stride': [8, 16, 32],
+        'out_stride': 16,
         # Transformer Decoder
-        'transformer': 'rtdetr_transformer',
+        'transformer': 'plain_detr_transformer',
         'hidden_dim': 256,
+        'num_queries': 300,
+        'num_queries_one2one': 300,
+        'num_queries_one2many': 0,    # one-to-many branch disabled so the same 300 queries work for both the train and eval smoke tests
         'de_num_heads': 8,
         'de_num_layers': 6,
         'de_mlp_ratio': 4.0,
         'de_dropout': 0.1,
         'de_act': 'gelu',
-        'de_num_points': 4,
-        'num_queries': 300,
-        'learnt_init_query': False,
-        'pe_temperature': 10000.,
-        'dn_num_denoising': 100,
-        'dn_label_noise_ratio': 0.5,
-        'dn_box_noise_scale': 1,
+        'de_pre_norm': True,
+        'rpe_hidden_dim': 512,
+        'use_checkpoint': False,
+        'proposal_feature_levels': 3,
+        'proposal_tgt_strides': [8, 16, 32],
     }
-    bs = 1
-    hidden_dim = cfg['hidden_dim']
-    in_dims = [hidden_dim] * 3
-    targets = [{
-        'labels': torch.tensor([2, 4, 5, 8]).long(),
-        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float()
-    }] * bs
-    pyramid_feats = [torch.randn(bs, hidden_dim, 80, 80),
-                     torch.randn(bs, hidden_dim, 40, 40),
-                     torch.randn(bs, hidden_dim, 20, 20)]
-    model = build_transformer(cfg, in_dims, 80, True)
-    model.train()
+    feat = torch.randn(1, cfg['hidden_dim'], 40, 40)
+    mask = torch.zeros(1, 40, 40, dtype=torch.bool)  # padding mask: True marks padded positions
+    pos_embed = torch.randn(1, cfg['hidden_dim'], 40, 40)
+    query_embed = torch.randn(cfg['num_queries'], cfg['hidden_dim'])
 
+    model = build_transformer(cfg, True)
+
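+    # get_reference_points indexes class_embed[num_layers] / bbox_embed[num_layers], so the decoder needs de_num_layers + 1 head copies attached before the forward pass.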
+    class_embed = nn.Linear(cfg['hidden_dim'], 80)
+    bbox_embed = MLP(cfg['hidden_dim'], cfg['hidden_dim'], 4, 3)
+    class_embed = get_clones(class_embed, cfg['de_num_layers'] + 1)
+    bbox_embed = get_clones(bbox_embed, cfg['de_num_layers'] + 1)
+
+    model.decoder.bbox_embed = bbox_embed
+    model.decoder.class_embed = class_embed
+
+    model.train()
     t0 = time.time()
-    outputs = model(pyramid_feats, targets)
-    out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
+    outputs = model(feat, mask, pos_embed, query_embed)
+    (hs,
+     init_reference_out,
+     inter_references_out,
+     enc_outputs_class,
+     enc_outputs_coord_unact,
+     enc_outputs_delta,
+     output_proposals,
+     max_shape
+     ) = outputs
     t1 = time.time()
     print('Time: ', t1 - t0)
-    print(out_bboxes.shape)
-    print(out_logits.shape)
-    print(enc_topk_bboxes.shape)
-    print(enc_topk_logits.shape)
+    print(hs.shape)
+    print(init_reference_out.shape)
+    print(inter_references_out.shape)
+    print(enc_outputs_class.shape)
+    print(enc_outputs_coord_unact.shape)
+    print(enc_outputs_delta.shape)
+    print(output_proposals.shape)
 
     print('==============================')
     model.eval()
-    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    flops, params = profile(model, inputs=(feat, mask, pos_embed, query_embed, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
     print('Params : {:.2f} M'.format(params / 1e6))