Selaa lähdekoodia

RandomJitterCrop

yjh0410 1 vuosi sitten
vanhempi
sitoutus
49b72e2035

+ 73 - 2
dataset/data_augment/rtdetr_augment.py

@@ -263,6 +263,78 @@ class RandomSampleCrop(object):
 
                 return current_image, target
 
+## Random JitterCrop
+class RandomJitterCrop(object):
+    """Jitter and crop the image and box."""
+    def __init__(self, fill_value, p=0.5, jitter_ratio=0.3):
+        super().__init__()
+        self.p = p
+        self.jitter_ratio = jitter_ratio
+        self.fill_value = fill_value
+
+    def crop(self, image, pleft, pright, ptop, pbot, output_size):
+        oh, ow = image.shape[:2]
+
+        swidth, sheight = output_size
+
+        src_rect = [pleft, ptop, swidth + pleft,
+                    sheight + ptop]  # x1,y1,x2,y2
+        img_rect = [0, 0, ow, oh]
+        # rect intersection
+        new_src_rect = [max(src_rect[0], img_rect[0]),
+                        max(src_rect[1], img_rect[1]),
+                        min(src_rect[2], img_rect[2]),
+                        min(src_rect[3], img_rect[3])]
+        dst_rect = [max(0, -pleft),
+                    max(0, -ptop),
+                    max(0, -pleft) + new_src_rect[2] - new_src_rect[0],
+                    max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
+
+        # crop the image
+        cropped = np.ones([sheight, swidth, 3], dtype=image.dtype) * self.fill_value
+        # cropped[:, :, ] = np.mean(image, axis=(0, 1))
+        cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = \
+            image[new_src_rect[1]:new_src_rect[3],
+            new_src_rect[0]:new_src_rect[2]]
+
+        return cropped
+
+    def __call__(self, image, target=None):
+        if random.random() > self.p:
+            return image, target
+        else:
+            oh, ow = image.shape[:2]
+            dw = int(ow * self.jitter_ratio)
+            dh = int(oh * self.jitter_ratio)
+            pleft = np.random.randint(-dw, dw)
+            pright = np.random.randint(-dw, dw)
+            ptop = np.random.randint(-dh, dh)
+            pbot = np.random.randint(-dh, dh)
+
+            swidth = ow - pleft - pright
+            sheight = oh - ptop - pbot
+            output_size = (swidth, sheight)
+            # crop image
+            cropped_image = self.crop(image=image,
+                                    pleft=pleft, 
+                                    pright=pright, 
+                                    ptop=ptop, 
+                                    pbot=pbot,
+                                    output_size=output_size)
+            # crop bbox
+            if target is not None:
+                bboxes = target['boxes'].copy()
+                coords_offset = np.array([pleft, ptop], dtype=np.float32)
+                bboxes[..., [0, 2]] = bboxes[..., [0, 2]] - coords_offset[0]
+                bboxes[..., [1, 3]] = bboxes[..., [1, 3]] - coords_offset[1]
+                swidth, sheight = output_size
+
+                bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], 0, swidth - 1)
+                bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], 0, sheight - 1)
+                target['boxes'] = bboxes
+
+            return cropped_image, target
+    
 ## Random HFlip
 class RandomHorizontalFlip(object):
     def __init__(self, p=0.5):
@@ -355,8 +427,7 @@ class RTDetrAugmentation(object):
             # For no-mosaic setting, we use RandomExpand & RandomSampleCrop processor.
             self.augment = Compose([
                 RandomPhotometricDistort(hue=0.5, saturation=1.5, exposure=1.5),
-                RandomExpand(self.pixel_mean[::-1]),
-                RandomSampleCrop(),
+                RandomJitterCrop(p=0.8, jitter_ratio=0.3, fill_value=self.pixel_mean[::-1]),
                 RandomHorizontalFlip(p=0.5),
                 Resize(img_size=self.img_size),
                 ConvertColorFormat(self.color_format),

+ 1 - 1
engine.py

@@ -1146,7 +1146,7 @@ class RTRTrainer(object):
 
         # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
-        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 1.0, 'warmup_iters': 2000} # no lr decay
+        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1, 'warmup_iters': 2000} # no lr decay
         self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
 
         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------

+ 24 - 1
models/detectors/rtdetr/basic_modules/transformer.py

@@ -4,7 +4,7 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_
+from torch.nn.init import constant_, xavier_uniform_, uniform_
 
 try:
     from .basic import get_activation
@@ -260,9 +260,19 @@ class TransformerEncoderLayer(nn.Module):
         # Feedforwaed Network
         self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
 
+        self._reset_parameters()
+
     def with_pos_embed(self, tensor, pos):
         return tensor if pos is None else tensor + pos
 
+    def _reset_parameters(self):
+        def linear_init_(module):
+            bound = 1 / math.sqrt(module.weight.shape[0])
+            uniform_(module.weight, -bound, bound)
+            if hasattr(module, "bias") and module.bias is not None:
+                uniform_(module.bias, -bound, bound)
+        linear_init_(self.ffn.linear1)
+        linear_init_(self.ffn.linear2)
 
     def forward(self, src, pos_embed):
         """
@@ -395,9 +405,22 @@ class DeformableTransformerDecoderLayer(nn.Module):
         ## FFN
         self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
 
+        self._reset_parameters()
+
     def with_pos_embed(self, tensor, pos):
         return tensor if pos is None else tensor + pos
 
+    def _reset_parameters(self):
+        def linear_init_(module):
+            bound = 1 / math.sqrt(module.weight.shape[0])
+            uniform_(module.weight, -bound, bound)
+            if hasattr(module, "bias") and module.bias is not None:
+                uniform_(module.bias, -bound, bound)
+        linear_init_(self.ffn.linear1)
+        linear_init_(self.ffn.linear2)
+        xavier_uniform_(self.ffn.linear1.weight)
+        xavier_uniform_(self.ffn.linear2.weight)
+
     def forward(self,
                 tgt,
                 reference_points,

+ 104 - 0
models/detectors/rtpdetr/basic_modules/basic.py

@@ -1,7 +1,60 @@
+import math
 import torch
 import torch.nn as nn
 
 
+# ----------------- Customed NormLayer Ops -----------------
+class FrozenBatchNorm2d(torch.nn.Module):
+    def __init__(self, n):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        eps = 1e-5
+        scale = w * (rv + eps).rsqrt()
+        bias = b - rm * scale
+        return x * scale + bias
+
+class LayerNorm2D(nn.Module):
+    def __init__(self, normalized_shape, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.ln = norm_layer(normalized_shape) if norm_layer is not None else nn.Identity()
+
+    def forward(self, x):
+        """
+        x: N C H W
+        """
+        x = x.permute(0, 2, 3, 1)
+        x = self.ln(x)
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+
+# ----------------- Basic CNN Ops -----------------
+def get_conv2d(c1, c2, k, p, s, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
+
+    return conv
+
 def get_activation(act_type=None):
     if act_type == 'relu':
         return nn.ReLU(inplace=True)
@@ -29,6 +82,57 @@ def get_norm(norm_type, dim):
         raise NotImplementedError
 
 
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # padding
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                ):
+        super(BasicConv, self).__init__()
+        add_bias = False if norm_type else True
+        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class UpSampleWrapper(nn.Module):
+    """Upsample last feat map to specific stride."""
+    def __init__(self, in_dim, upsample_factor):
+        super(UpSampleWrapper, self).__init__()
+        # ---------- Basic parameters ----------
+        self.upsample_factor = upsample_factor
+
+        # ---------- Network parameters ----------
+        if upsample_factor == 1:
+            self.upsample = nn.Identity()
+        else:
+            scale = int(math.log2(upsample_factor))
+            dim = in_dim
+            layers = []
+            for _ in range(scale-1):
+                layers += [
+                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                    LayerNorm2D(dim // 2),
+                    nn.GELU()
+                ]
+                dim = dim // 2
+            layers += [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
+            dim = dim // 2
+            self.upsample = nn.Sequential(*layers)
+            self.out_dim = dim
+
+    def forward(self, x):
+        x = self.upsample(x)
+
+        return x
+
+
 # ----------------- MLP modules -----------------
 class MLP(nn.Module):
     def __init__(self, in_dim, hidden_dim, out_dim, num_layers):

+ 2 - 2
models/detectors/rtpdetr/basic_modules/transformer.py

@@ -153,7 +153,7 @@ class TransformerEncoder(nn.Module):
         return src
 
 ## Transformer Decoder layer
-class TransformerDecoderLayer(nn.Module):
+class PlainTransformerDecoderLayer(nn.Module):
     def __init__(self,
                  d_model     :int   = 256,
                  num_heads   :int   = 8,
@@ -221,7 +221,7 @@ class TransformerDecoderLayer(nn.Module):
         return tgt
 
 ## Transformer Decoder
-class TransformerDecoder(nn.Module):
+class PlainTransformerDecoder(nn.Module):
     def __init__(self,
                  d_model        :int   = 256,
                  num_heads      :int   = 8,

+ 363 - 0
models/detectors/rtpdetr/rtpdetr_decoder.py

@@ -0,0 +1,363 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import constant_, xavier_uniform_, uniform_, normal_
+from typing import List
+
+try:
+    from .basic_modules.basic import BasicConv, MLP
+    from .basic_modules.transformer import PlainTransformerDecoder
+except:
+    from  basic_modules.basic import BasicConv, MLP
+    from  basic_modules.transformer import PlainTransformerDecoder
+
+
+def build_transformer(cfg, in_dims, num_classes, return_intermediate=False):
+    if cfg['transformer'] == 'plain_detr_transformer':
+        return PlainDETRTransformer(in_dims             = in_dims,
+                                 hidden_dim          = cfg['hidden_dim'],
+                                 strides             = cfg['out_stride'],
+                                 num_classes         = num_classes,
+                                 num_queries         = cfg['num_queries'],
+                                 pos_embed_type      = 'sine',
+                                 num_heads           = cfg['de_num_heads'],
+                                 num_layers          = cfg['de_num_layers'],
+                                 num_levels          = len(cfg['out_stride']),
+                                 num_points          = cfg['de_num_points'],
+                                 mlp_ratio           = cfg['de_mlp_ratio'],
+                                 dropout             = cfg['de_dropout'],
+                                 act_type            = cfg['de_act'],
+                                 return_intermediate = return_intermediate,
+                                 num_denoising       = cfg['dn_num_denoising'],
+                                 label_noise_ratio   = cfg['dn_label_noise_ratio'],
+                                 box_noise_scale     = cfg['dn_box_noise_scale'],
+                                 learnt_init_query   = cfg['learnt_init_query'],
+                                 )
+
+
+# ----------------- Dencoder for Detection task -----------------
+## RTDETR's Transformer for Detection task
+class PlainDETRTransformer(nn.Module):
+    def __init__(self,
+                 # basic parameters
+                 in_dims        :List = [256, 512, 1024],
+                 hidden_dim     :int  = 256,
+                 strides        :List = [8, 16, 32],
+                 num_classes    :int  = 80,
+                 num_queries    :int  = 300,
+                 pos_embed_type :str  = 'sine',
+                 # transformer parameters
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 mlp_ratio      :float = 4.0,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 # Denoising parameters
+                 num_denoising       :int  = 100,
+                 label_noise_ratio   :float = 0.5,
+                 box_noise_scale     :float = 1.0,
+                 learnt_init_query   :bool  = True,
+                 ):
+        super().__init__()
+        # --------------- Basic setting ---------------
+        ## Basic parameters
+        self.in_dims = in_dims
+        self.strides = strides
+        self.num_queries = num_queries
+        self.pos_embed_type = pos_embed_type
+        self.num_classes = num_classes
+        self.eps = 1e-2
+        ## Transformer parameters
+        self.num_heads  = num_heads
+        self.num_layers = num_layers
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.mlp_ratio  = mlp_ratio
+        self.dropout    = dropout
+        self.act_type   = act_type
+        self.return_intermediate = return_intermediate
+        ## Denoising parameters
+        self.num_denoising = num_denoising
+        self.label_noise_ratio = label_noise_ratio
+        self.box_noise_scale = box_noise_scale
+        self.learnt_init_query = learnt_init_query
+
+        # --------------- Network setting ---------------
+        ## Input proj layers
+        self.input_proj_layers = nn.ModuleList(
+            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
+            for i in range(num_levels)
+        )
+
+        ## Deformable transformer decoder
+        self.decoder = PlainTransformerDecoder(
+                                    d_model    = hidden_dim,
+                                    num_heads  = num_heads,
+                                    num_layers = num_layers,
+                                    num_levels = num_levels,
+                                    num_points = num_points,
+                                    mlp_ratio  = mlp_ratio,
+                                    dropout    = dropout,
+                                    act_type   = act_type,
+                                    return_intermediate = return_intermediate
+                                    )
+        
+        ## Detection head for Encoder
+        self.enc_output = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim)
+            )
+        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
+        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+
+        ## Detection head for Decoder
+        self.dec_class_head = nn.ModuleList([
+            nn.Linear(hidden_dim, num_classes)
+            for _ in range(num_layers)
+        ])
+        self.dec_bbox_head = nn.ModuleList([
+            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+            for _ in range(num_layers)
+        ])
+
+        ## Object query
+        if learnt_init_query:
+            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
+        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
+
+        ## Denoising part
+        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        def linear_init_(module):
+            bound = 1 / math.sqrt(module.weight.shape[0])
+            uniform_(module.weight, -bound, bound)
+            if hasattr(module, "bias") and module.bias is not None:
+                uniform_(module.bias, -bound, bound)
+
+        # class and bbox head init
+        prior_prob = 0.01
+        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
+        linear_init_(self.enc_class_head)
+        constant_(self.enc_class_head.bias, cls_bias_init)
+        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
+        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
+            linear_init_(cls_)
+            constant_(cls_.bias, cls_bias_init)
+            constant_(reg_.layers[-1].weight, 0.)
+            constant_(reg_.layers[-1].bias, 0.)
+
+        linear_init_(self.enc_output[0])
+        xavier_uniform_(self.enc_output[0].weight)
+        if self.learnt_init_query:
+            xavier_uniform_(self.tgt_embed.weight)
+        xavier_uniform_(self.query_pos_head.layers[0].weight)
+        xavier_uniform_(self.query_pos_head.layers[1].weight)
+        for l in self.input_proj_layers:
+            xavier_uniform_(l.conv.weight)
+        normal_(self.denoising_class_embed.weight)
+
+    def generate_anchors(self, spatial_shapes, grid_size=0.05):
+        anchors = []
+        for lvl, (h, w) in enumerate(spatial_shapes):
+            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+            # [H, W, 2]
+            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
+
+            valid_WH = torch.as_tensor([w, h]).float()
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
+            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
+            # [H, W, 4] -> [1, N, 4], N=HxW
+            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
+        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
+        anchors = torch.cat(anchors, dim=1)
+        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
+        anchors = torch.log(anchors / (1 - anchors))
+        # Equal to operation: anchors = torch.masked_fill(anchors, ~valid_mask, torch.as_tensor(float("inf")))
+        anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
+        
+        return anchors, valid_mask
+    
+    def get_encoder_input(self, feats):
+        # get projection features
+        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
+
+        # get encoder inputs
+        feat_flatten = []
+        spatial_shapes = []
+        level_start_index = [0, ]
+        for i, feat in enumerate(proj_feats):
+            _, _, h, w = feat.shape
+            spatial_shapes.append([h, w])
+            # [l], start index of each level
+            level_start_index.append(h * w + level_start_index[-1])
+            # [B, C, H, W] -> [B, N, C], N=HxW
+            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
+
+        # [B, N, C], N = N_0 + N_1 + ...
+        feat_flatten = torch.cat(feat_flatten, dim=1)
+        level_start_index.pop()
+
+        return (feat_flatten, spatial_shapes, level_start_index)
+
+    def get_decoder_input(self,
+                          memory,
+                          spatial_shapes,
+                          denoising_class=None,
+                          denoising_bbox_unact=None):
+        bs, _, _ = memory.shape
+        # Prepare input for decoder
+        anchors, valid_mask = self.generate_anchors(spatial_shapes)
+        anchors = anchors.to(memory.device)
+        valid_mask = valid_mask.to(memory.device)
+        
+        # Process encoder's output
+        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
+        output_memory = self.enc_output(memory)
+
+        # Head for encoder's output : [bs, num_quries, c]
+        enc_outputs_class = self.enc_class_head(output_memory)
+        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
+
+        # Topk proposals from encoder's output
+        topk = self.num_queries
+        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
+        enc_topk_logits = torch.gather(
+            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
+        reference_points_unact = torch.gather(
+            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))    # [bs, num_queries, 4]
+        enc_topk_bboxes = F.sigmoid(reference_points_unact)
+
+        if denoising_bbox_unact is not None:
+            reference_points_unact = torch.cat(
+                [denoising_bbox_unact, reference_points_unact], 1)
+        if self.training:
+            reference_points_unact = reference_points_unact.detach()
+
+        # Extract region features
+        if self.learnt_init_query:
+            # [num_queries, c] -> [b, num_queries, c]
+            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+        else:
+            # [num_queries, c] -> [b, num_queries, c]
+            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+            if self.training:
+                target = target.detach()
+        if denoising_class is not None:
+            target = torch.cat([denoising_class, target], dim=1)
+
+        return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
+    
+    def forward(self, feats, targets=None):
+        # input projection and embedding
+        memory, spatial_shapes, _ = self.get_encoder_input(feats)
+
+        # prepare denoising training
+        if self.training:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
+                get_contrastive_denoising_training_group(targets,
+                                                         self.num_classes,
+                                                         self.num_queries,
+                                                         self.denoising_class_embed.weight,
+                                                         self.num_denoising,
+                                                         self.label_noise_ratio,
+                                                         self.box_noise_scale)
+        else:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
+
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
+            self.get_decoder_input(
+            memory, spatial_shapes, denoising_class, denoising_bbox_unact)
+
+        # decoder
+        out_bboxes, out_logits = self.decoder(target,
+                                              init_ref_points_unact,
+                                              memory,
+                                              spatial_shapes,
+                                              self.dec_bbox_head,
+                                              self.dec_class_head,
+                                              self.query_pos_head,
+                                              attn_mask)
+        
+        return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
+
+
+# ----------------- Dencoder for Segmentation task -----------------
+## RTDETR's Transformer for Segmentation task
+class SegTransformerDecoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        # TODO: design seg-decoder
+
+    def forward(self, x):
+        return
+
+
+# ----------------- Dencoder for Pose estimation task -----------------
+## RTDETR's Transformer for Pose estimation task
+class PosTransformerDecoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        # TODO: design seg-decoder
+
+    def forward(self, x):
+        return
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'out_stride': [8, 16, 32],
+        # Transformer Decoder
+        'transformer': 'rtdetr_transformer',
+        'hidden_dim': 256,
+        'de_num_heads': 8,
+        'de_num_layers': 6,
+        'de_mlp_ratio': 4.0,
+        'de_dropout': 0.1,
+        'de_act': 'gelu',
+        'de_num_points': 4,
+        'num_queries': 300,
+        'learnt_init_query': False,
+        'pe_temperature': 10000.,
+        'dn_num_denoising': 100,
+        'dn_label_noise_ratio': 0.5,
+        'dn_box_noise_scale': 1,
+    }
+    bs = 1
+    hidden_dim = cfg['hidden_dim']
+    in_dims = [hidden_dim] * 3
+    targets = [{
+        'labels': torch.tensor([2, 4, 5, 8]).long(),
+        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float()
+    }] * bs
+    pyramid_feats = [torch.randn(bs, hidden_dim, 80, 80),
+                     torch.randn(bs, hidden_dim, 40, 40),
+                     torch.randn(bs, hidden_dim, 20, 20)]
+    model = build_transformer(cfg, in_dims, 80, True)
+    model.train()
+
+    t0 = time.time()
+    outputs = model(pyramid_feats, targets)
+    out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(out_bboxes.shape)
+    print(out_logits.shape)
+    print(enc_topk_bboxes.shape)
+    print(enc_topk_logits.shape)
+
+    print('==============================')
+    model.eval()
+    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 78 - 0
models/detectors/rtpdetr/rtpdetr_encoder.py

@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from .basic_modules.basic    import BasicConv, UpSampleWrapper
+    from .basic_modules.backbone import build_backbone
+except:
+    from  basic_modules.basic    import BasicConv, UpSampleWrapper
+    from  basic_modules.backbone import build_backbone
+
+
+# ----------------- Image Encoder -----------------
+def build_image_encoder(cfg):
+    return ImageEncoder(cfg)
+
+class ImageEncoder(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # ---------------- Basic settings ----------------
+        ## Basic parameters
+        self.cfg = cfg
+        ## Network parameters
+        self.stride = cfg['out_stride']
+        self.upsample_factor = 32 // self.stride
+        self.hidden_dim = cfg['hidden_dim']
+        
+        # ---------------- Network settings ----------------
+        ## Backbone Network
+        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained']&self.training)
+
+        ## Upsample layer
+        self.upsample = UpSampleWrapper(fpn_feat_dims[-1], self.upsample_factor)
+        
+        ## Input projection
+        self.input_proj = BasicConv(self.upsample.out_dim, self.hidden_dim, kernel_size=1, act_type=None, norm_type='BN')
+
+
+    def forward(self, x):
+        pyramid_feats = self.backbone(x)
+        feat = self.upsample(pyramid_feats[-1])
+        feat = self.input_proj(feat)
+
+        return feat
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'width': 1.0,
+        'depth': 1.0,
+        'out_stride': 16,
+        # Image Encoder - Backbone
+        'backbone': 'resnet50',
+        'backbone_norm': 'BN',
+        'res5_dilation': False,
+        'pretrained': True,
+        'pretrained_weight': 'imagenet1k_v1',        
+        'hidden_dim': 256,
+    }
+    x = torch.rand(2, 3, 640, 640)
+    model = build_image_encoder(cfg)
+    model.train()
+
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(outputs.shape)
+
+    print('==============================')
+    model.eval()
+    x = torch.rand(1, 3, 640, 640)
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))