
add RT-DETR's modules

yjh0410 1 year ago
parent
commit 996544c30c

+ 31 - 1
config/data_config/transform_config.py

@@ -272,9 +272,39 @@ ssd_trans_config = {
 
 
 # ----------------------- RT-DETR-Style Transform -----------------------
-rtdetr_trans_config = {
+rtdetr_base_trans_config = {
     'aug_type': 'rtdetr',
     'use_ablu': False,
+    'pixel_mean': [123.675, 116.28, 103.53],  # IN-1K statistics
+    'pixel_std':  [58.395, 57.12, 57.375],    # IN-1K statistics
+    # Mosaic & Mixup are not used for RT_DETR-style augmentation
+    'mosaic_prob': 0.,
+    'mixup_prob': 0.,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': False,
+    'mixup_scale': [0.5, 1.5]
+}
+
+rtdetr_l_trans_config = {
+    'aug_type': 'rtdetr',
+    'use_ablu': False,
+    'pixel_mean': [0., 0., 0.],
+    'pixel_std':  [255., 255., 255.],
+    # Mosaic & Mixup are not used for RT_DETR-style augmentation
+    'mosaic_prob': 0.,
+    'mixup_prob': 0.,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': False,
+    'mixup_scale': [0.5, 1.5]
+}
+
+rtdetr_x_trans_config = {
+    'aug_type': 'rtdetr',
+    'use_ablu': False,
+    'pixel_mean': [0., 0., 0.],
+    'pixel_std':  [255., 255., 255.],
     # Mosaic & Mixup are not used for RT_DETR-style augmentation
     'mosaic_prob': 0.,
     'mixup_prob': 0.,
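
For orientation, a minimal sketch (not part of the commit) of what the two pixel-statistic choices above amount to inside the RT-DETR transforms: the base config normalizes with ImageNet-1K statistics, while the L/X configs simply scale pixels to [0, 1].

import numpy as np

img = np.random.randint(0, 256, size=(640, 640, 3)).astype(np.float32)  # dummy RGB image

# rtdetr_base_trans_config: ImageNet-1K statistics -> roughly zero-mean, unit-variance inputs
in1k_mean = np.array([123.675, 116.28, 103.53])
in1k_std  = np.array([58.395, 57.12, 57.375])
normalized = (img - in1k_mean) / in1k_std

# rtdetr_l / rtdetr_x configs: mean 0, std 255 -> plain scaling of pixels into [0, 1]
scaled = (img - np.array([0., 0., 0.])) / np.array([255., 255., 255.])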

+ 7 - 14
dataset/build.py

@@ -108,33 +108,26 @@ def build_transform(args, trans_config, max_stride=32, is_train=False):
             trans_config['mixup_prob'] = args.mixup
 
     # ---------------- Build transform ----------------
-    ## SSD-style transform
+    ## SSD style transform
     if trans_config['aug_type'] == 'ssd':
         if is_train:
             transform = SSDAugmentation(img_size=args.img_size,)
         else:
             transform = SSDBaseTransform(img_size=args.img_size,)
-    ## YOLO-style transform
+    ## YOLO style transform
     elif trans_config['aug_type'] == 'yolov5':
         if is_train:
-            transform = YOLOv5Augmentation(
-                img_size=args.img_size,
-                trans_config=trans_config,
-                use_ablu=trans_config['use_ablu']
-                )
+            transform = YOLOv5Augmentation(img_size=args.img_size, trans_config=trans_config, use_ablu=trans_config['use_ablu'])
         else:
-            transform = YOLOv5BaseTransform(
-                img_size=args.img_size,
-                max_stride=max_stride
-                )
-    ## RT_DETR-style transform
+            transform = YOLOv5BaseTransform(img_size=args.img_size, max_stride=max_stride)
+    ## RT-DETR style transform
     elif trans_config['aug_type'] == 'rtdetr':
         if is_train:
             use_mosaic = False if trans_config['mosaic_prob'] < 0.2 else True
             transform = RTDetrAugmentation(
-                img_size=args.img_size, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], use_mosaic=use_mosaic)
+                img_size=args.img_size, pixel_mean=trans_config['pixel_mean'], pixel_std=trans_config['pixel_std'], use_mosaic=use_mosaic)
         else:
             transform = RTDetrBaseTransform(
-                img_size=args.img_size, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375])
+                img_size=args.img_size, pixel_mean=trans_config['pixel_mean'], pixel_std=trans_config['pixel_std'])
 
     return transform, trans_config
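
A hedged usage sketch (not in the commit) of how the pixel statistics now flow from a transform config into the RT-DETR transform; the import paths and constructor arguments follow the file names and calls shown above.

from config.data_config.transform_config import rtdetr_base_trans_config
from dataset.data_augment.rtdetr_augment import RTDetrBaseTransform

# Evaluation-time transform built from the config dict rather than hard-coded constants
transform = RTDetrBaseTransform(img_size=640,
                                pixel_mean=rtdetr_base_trans_config['pixel_mean'],
                                pixel_std=rtdetr_base_trans_config['pixel_std'])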

+ 2 - 0
dataset/coco.py

@@ -269,6 +269,8 @@ if __name__ == "__main__":
 
     trans_config = {
         'aug_type': args.aug_type,    # optional: ssd, yolov5
+        'pixel_mean': [0., 0., 0.],
+        'pixel_std':  [255., 255., 255.],
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,

+ 2 - 0
dataset/crowdhuman.py

@@ -204,6 +204,8 @@ if __name__ == "__main__":
 
     trans_config = {
         'aug_type': args.aug_type,    # optional: ssd, yolov5
+        'pixel_mean': [0., 0., 0.],
+        'pixel_std':  [255., 255., 255.],
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,

+ 2 - 0
dataset/customed.py

@@ -263,6 +263,8 @@ if __name__ == "__main__":
 
     trans_config = {
         'aug_type': args.aug_type,    # optional: ssd, yolov5
+        'pixel_mean': [0., 0., 0.],
+        'pixel_std':  [255., 255., 255.],
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,

+ 6 - 0
dataset/data_augment/rtdetr_augment.py

@@ -307,6 +307,9 @@ class RTDetrAugmentation(object):
         self.pixel_mean = pixel_mean  # RGB format
         self.pixel_std = pixel_std    # RGB format
         self.color_format = 'rgb'
+        print("================= Pixel Statistics =================")
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
 
         # ----------------- Transforms -----------------
         if use_mosaic:
@@ -348,6 +351,9 @@ class RTDetrBaseTransform(object):
         self.pixel_mean = pixel_mean  # RGB format
         self.pixel_std = pixel_std    # RGB format
         self.color_format = 'rgb'
+        print("================= Pixel Statistics =================")
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
 
         # ----------------- Transforms -----------------
         self.transform = Compose([

+ 2 - 0
dataset/voc.py

@@ -273,6 +273,8 @@ if __name__ == "__main__":
 
     trans_config = {
         'aug_type': args.aug_type,    # optional: ssd, yolov5
+        'pixel_mean': [0., 0., 0.],
+        'pixel_std':  [255., 255., 255.],
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,

+ 2 - 0
dataset/widerface.py

@@ -207,6 +207,8 @@ if __name__ == "__main__":
 
     trans_config = {
         'aug_type': args.aug_type,    # optional: ssd, yolov5
+        'pixel_mean': [0., 0., 0.],
+        'pixel_std':  [255., 255., 255.],
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,

+ 5 - 8
engine.py

@@ -1133,6 +1133,7 @@ class RTRTrainer(object):
         self.world_size = world_size
         self.grad_accumulate = args.grad_accumulate
         self.clip_grad = 0.1
+        self.args.fp16 = False   # No AMP for DETR
         # path to save model
         self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
         os.makedirs(self.path_to_save, exist_ok=True)
@@ -1140,10 +1141,9 @@ class RTRTrainer(object):
         # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
         self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1}
-        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}        
 
         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
-        self.data_cfg = data_cfg
+        self.data_cfg  = data_cfg
         self.model_cfg = model_cfg
         self.trans_cfg = trans_cfg
 
@@ -1248,14 +1248,11 @@ class RTRTrainer(object):
             # Warmup
             if ni <= nw:
                 xi = [0, nw]  # x interp
-                for j, x in enumerate(self.optimizer.param_groups):
-                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                    x['lr'] = np.interp( ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
-                    if 'momentum' in x:
-                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
+                for x in self.optimizer.param_groups:
+                    x['lr'] = np.interp(ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                                 
             # To device
-            images = images.to(self.device, non_blocking=True).float() / 255.
+            images = images.to(self.device, non_blocking=True).float()
 
             # Multi scale
             if self.args.multi_scale:
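
Two related changes above: normalization now happens inside the RT-DETR transform, so the trainer no longer divides images by 255, and warmup reduces to a single linear ramp of every param group's lr (the bias/momentum warmup is dropped). A self-contained sketch of that ramp, with illustrative values for nw, initial_lr, and the lr factor lf:

import numpy as np

nw, initial_lr = 1000, 1e-4          # illustrative warmup length and base lr
lf = lambda epoch: 1.0               # placeholder for the cosine schedule factor
for ni in (0, 250, 500, 1000):
    lr = np.interp(ni, [0, nw], [0.0, initial_lr * lf(0)])
    print(f"iter {ni:4d} -> lr {lr:.6f}")   # ramps 0.0 -> 2.5e-05 -> 5e-05 -> 1e-04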

+ 5 - 1
models/detectors/rtcdet/rtcdet_backbone.py

@@ -177,4 +177,8 @@ if __name__ == '__main__':
     flops, params = profile(model, inputs=(x, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    print('Params : {:.2f} M'.format(params / 1e6))
+
+
+    for n, p in model.named_parameters():
+        print(n)

+ 5 - 1
models/detectors/rtdetr/basic_modules/backbone.py

@@ -103,7 +103,7 @@ def build_scnetv2(cfg, pretrained_weight=None):
 if __name__ == '__main__':
     cfg = {
         'backbone':      'resnet18',
-        'backbone_norm': 'FrozeBN',
+        'backbone_norm': 'BN',
         'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
@@ -115,3 +115,7 @@ if __name__ == '__main__':
     output = model(x)
     for y in output:
         print(y.size())
+
+    for n, p in model.named_parameters():
+        print(n.split(".")[-1])
+

+ 194 - 35
models/detectors/rtdetr/basic_modules/basic.py

@@ -1,10 +1,43 @@
+import math
 import torch
 import torch.nn as nn
 
 
+# ----------------- MLP modules -----------------
+class MLP(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+class FFN(nn.Module):
+    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
+        super().__init__()
+        self.ffn_dim = round(d_model * mlp_ratio)
+        self.linear1 = nn.Linear(d_model, self.ffn_dim)
+        self.activation = get_activation(act_type)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(self.ffn_dim, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm(src)
+        
+        return src
+    
+
 # ----------------- CNN modules -----------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+def get_conv2d(c1, c2, k, p, s, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
 
     return conv
 
@@ -79,46 +112,172 @@ class FrozenBatchNorm2d(torch.nn.Module):
         bias = b - rm * scale
         return x * scale + bias
     
-class Conv(nn.Module):
+class BasicConv(nn.Module):
     def __init__(self, 
-                 c1,                   # in channels
-                 c2,                   # out channels 
-                 k=1,                  # kernel size 
-                 p=0,                  # padding
-                 s=1,                  # padding
-                 d=1,                  # dilation
-                 act_type  :str  = 'lrelu',   # activation
-                 norm_type :str  ='BN',       # normalization
-                 depthwise :bool =False):
-        super(Conv, self).__init__()
-        convs = []
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                ):
+        super(BasicConv, self).__init__()
         add_bias = False if norm_type else True
-        if depthwise:
-            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
-            # depthwise conv
-            if norm_type:
-                convs.append(get_norm(norm_type, c1))
-            if act_type:
-                convs.append(get_activation(act_type))
-            # pointwise conv
-            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
+        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class DepthwiseConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                 # in channels
+                 out_dim,                # out channels 
+                 kernel_size=1,          # kernel size 
+                 padding=0,              # padding
+                 stride=1,               # stride
+                 act_type  :str = None,  # activation
+                 norm_type :str = 'BN',  # normalization
+                ):
+        super(DepthwiseConv, self).__init__()
+        assert in_dim == out_dim
+        add_bias = False if norm_type else True
+        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class PointwiseConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                ):
+        super(PointwiseConv, self).__init__()
+        add_bias = False if norm_type else True
+        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
 
+## Yolov8's BottleNeck
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio = 0.5,
+                 kernel_sizes = [3, 3],
+                 shortcut     = True,
+                 act_type     = 'silu',
+                 norm_type    = 'BN',
+                 depthwise    = False,):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        if depthwise:
+            self.cv1 = nn.Sequential(
+                DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
+                PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
+            )
+            self.cv2 = nn.Sequential(
+                DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
+                PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
+            )
         else:
-            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-            
-        self.convs = nn.Sequential(*convs)
+            self.cv1 = BasicConv(in_dim, inter_dim,  kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
+            self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
+        self.shortcut = shortcut and in_dim == out_dim
 
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+# Yolov8's StageBlock
+class RTCBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks = 1,
+                 shortcut   = False,
+                 act_type   = 'silu',
+                 norm_type  = 'BN',
+                 depthwise  = False,):
+        super(RTCBlock, self).__init__()
+        self.inter_dim = out_dim // 2
+        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(*(
+            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
+            for _ in range(num_blocks)))
+        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
 
     def forward(self, x):
-        return self.convs(x)
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = list([x1, x2])
+
+        # Bottlenecks
+        out.extend(m(out[-1]) for m in self.m)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
 
 
 # ----------------- Transformer modules -----------------
+## Transformer layer
+class TransformerLayer(nn.Module):
+    def __init__(self,
+                 d_model         :int   = 256,
+                 num_heads       :int   = 8,
+                 mlp_ratio       :float = 4.0,
+                 dropout         :float = 0.1,
+                 act_type        :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        # ----------- Network layers -----------
+        # Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+        # Feedforward Network
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+
+    def forward(self, src, pos):
+        """
+        Input:
+            src: [torch.Tensor] -> [B, N, C]
+            pos: [torch.Tensor] -> [B, N, C]
+        Output:
+            src: [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos)
+
+        # -------------- MHSA --------------
+        src2 = self.self_attn(q, k, value=src)[0]  # nn.MultiheadAttention returns (output, attn_weights)
+        src = src + self.dropout(src2)
+        src = self.norm(src)
+
+        # -------------- FFN --------------
+        src = self.ffn(src)
+        
+        return src
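
A hedged shape check (not in the commit) for the new RTCBlock and TransformerLayer, assuming the module path matches the file above and that get_activation/get_norm in basic.py accept the defaults used here:

import torch
from models.detectors.rtdetr.basic_modules.basic import RTCBlock, TransformerLayer

block = RTCBlock(in_dim=256, out_dim=256, num_blocks=3, shortcut=True)
feat = torch.randn(2, 256, 40, 40)
print(block(feat).shape)            # expected: torch.Size([2, 256, 40, 40])

layer = TransformerLayer(d_model=256, num_heads=8)
tokens = torch.randn(2, 1600, 256)  # [B, N, C], e.g. a flattened 40x40 feature map
pos    = torch.randn(2, 1600, 256)  # positional embedding of the same shape
print(layer(tokens, pos).shape)     # expected: torch.Size([2, 1600, 256])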

+ 51 - 44
models/detectors/rtdetr/basic_modules/pafpn.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .basic import Conv, RTCBlock
+from .basic import BasicConv, RTCBlock
 
 
 # Build PaFPN
@@ -12,29 +12,38 @@ def build_pafpn(cfg, in_dims, out_dim):
 
 # ----------------- Feature Pyramid Network -----------------
 ## Real-time Convolutional PaFPN
-class RTCPaFPN(nn.Module):
+class HybridEncoder(nn.Module):
     def __init__(self, 
                  in_dims   = [256, 512, 512],
+                 out_dim   = 256,
                  width     = 1.0,
                  depth     = 1.0,
-                 ratio     = 1.0,
                  act_type  = 'silu',
                  norm_type = 'BN',
                  depthwise = False):
-        super(RTCPaFPN, self).__init__()
+        super(HybridEncoder, self).__init__()
         print('==============================')
         print('FPN: {}'.format("RTC-PaFPN"))
         # ---------------- Basic parameters ----------------
         self.in_dims = in_dims
+        self.out_dim = round(out_dim * width)
         self.width = width
         self.depth = depth
-        self.out_dim = [round(256 * width), round(512 * width), round(512 * width * ratio)]
         c3, c4, c5 = in_dims
 
-        # ---------------- Top dwon ----------------
+        # ---------------- Input projs ----------------
+        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+        # ---------------- Downsample ----------------
+        self.downsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
+        self.downsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
+
+        # ---------------- Top down FPN ----------------
         ## P5 -> P4
-        self.top_down_layer_1 = RTCBlock(in_dim       = c5 + c4,
-                                         out_dim      = round(512*width),
+        self.top_down_layer_1 = RTCBlock(in_dim       = self.out_dim * 2,
+                                         out_dim      = self.out_dim,
                                          num_blocks   = round(3*depth),
                                          shortcut     = False,
                                          act_type     = act_type,
@@ -42,29 +51,28 @@ class RTCPaFPN(nn.Module):
                                          depthwise    = depthwise,
                                          )
         ## P4 -> P3
-        self.top_down_layer_2 = RTCBlock(in_dim       = round(512*width) + c3,
-                                         out_dim      = round(256*width),
+        self.top_down_layer_2 = RTCBlock(in_dim       = self.out_dim * 2,
+                                         out_dim      = self.out_dim,
                                          num_blocks   = round(3*depth),
                                          shortcut     = False,
                                          act_type     = act_type,
                                          norm_type    = norm_type,
                                          depthwise    = depthwise,
                                          )
-        # ---------------- Bottom up ----------------
+        
+        # ---------------- Bottom up PAN ----------------
         ## P3 -> P4
-        self.dowmsample_layer_1 = Conv(round(256*width), round(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.bottom_up_layer_1  = RTCBlock(in_dim       = round(256*width) + round(512*width),
-                                           out_dim      = round(512*width),
-                                           num_blocks   = round(3*depth),
-                                           shortcut     = False,
-                                           act_type     = act_type,
-                                           norm_type    = norm_type,
-                                           depthwise    = depthwise,
-                                           )
+        self.bottom_up_layer_1 = RTCBlock(in_dim       = self.out_dim * 2,
+                                          out_dim      = self.out_dim,
+                                          num_blocks   = round(3*depth),
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise,
+                                          )
         ## P4 -> P5
-        self.dowmsample_layer_2 = Conv(round(512*width), round(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.bottom_up_layer_2 = RTCBlock(in_dim       = round(512 * width) + c5,
-                                          out_dim      = round(512 * width * ratio),
+        self.bottom_up_layer_2 = RTCBlock(in_dim       = self.out_dim * 2,
+                                          out_dim      = self.out_dim,
                                           num_blocks   = round(3*depth),
                                           shortcut     = False,
                                           act_type     = act_type,
@@ -85,26 +93,25 @@ class RTCPaFPN(nn.Module):
     def forward(self, features):
         c3, c4, c5 = features
 
-        # Top down
-        ## P5 -> P4
-        c6 = F.interpolate(c5, scale_factor=2.0)
-        c7 = torch.cat([c6, c4], dim=1)
-        c8 = self.top_down_layer_1(c7)
-        ## P4 -> P3
-        c9 = F.interpolate(c8, scale_factor=2.0)
-        c10 = torch.cat([c9, c3], dim=1)
-        c11 = self.top_down_layer_2(c10)
-
-        # Bottom up
-        # p3 -> P4
-        c12 = self.dowmsample_layer_1(c11)
-        c13 = torch.cat([c12, c8], dim=1)
-        c14 = self.bottom_up_layer_1(c13)
-        # P4 -> P5
-        c15 = self.dowmsample_layer_2(c14)
-        c16 = torch.cat([c15, c5], dim=1)
-        c17 = self.bottom_up_layer_2(c16)
-
-        out_feats = [c11, c14, c17] # [P3, P4, P5]
+        # -------- Input projs --------
+        p5 = self.reduce_layer_1(c5)
+        p4 = self.reduce_layer_2(c4)
+        p3 = self.reduce_layer_3(c3)
+
+        # -------- Top down FPN --------
+        p5_up = F.interpolate(p5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
+
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
+
+        # -------- Bottom up PAN --------
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
+
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
+
+        out_feats = [p3, p4, p5]
         
         return out_feats
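
A hedged usage sketch of the renamed HybridEncoder (not in the commit), assuming the constructor and forward shown in this hunk are the whole public interface; feature-map sizes are illustrative:

import torch
from models.detectors.rtdetr.basic_modules.pafpn import HybridEncoder

fpn = HybridEncoder(in_dims=[256, 512, 512], out_dim=256, width=1.0, depth=1.0)
c3 = torch.randn(2, 256, 80, 80)    # stride-8 feature
c4 = torch.randn(2, 512, 40, 40)    # stride-16 feature
c5 = torch.randn(2, 512, 20, 20)    # stride-32 feature
p3, p4, p5 = fpn([c3, c4, c5])
# all three outputs now share out_dim channels:
# [2, 256, 80, 80], [2, 256, 40, 40], [2, 256, 20, 20]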

+ 39 - 13
utils/solver/optimizer.py

@@ -50,20 +50,46 @@ def build_detr_optimizer(cfg, model, resume=None):
     print('--base lr: {}'.format(cfg['lr0']))
     print('--weight_decay: {}'.format(cfg['weight_decay']))
 
-    param_dicts = [
-        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
-        {
-            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
-            "lr": cfg['lr0'] * cfg['backbone_lr_ratio'],
-        },
-    ]
+    # param_dicts = [
+    #     {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
+    #     {
+    #         "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
+    #         "lr": cfg['lr0'] * cfg['backbone_lr_ratio'],
+    #     },
+    # ]
 
-    if cfg['optimizer'] == 'adam':
-        optimizer = torch.optim.Adam(param_dicts, lr=cfg['lr0'], weight_decay=cfg['weight_decay'])
-    elif cfg['optimizer'] == 'adamw':
-        optimizer = torch.optim.AdamW(param_dicts, lr=cfg['lr0'], weight_decay=cfg['weight_decay'])
-    else:
-        raise NotImplementedError('Optimizer {} not implemented.'.format(cfg['optimizer']))
+
+    # ------------- Divide model's parameters -------------
+    param_dicts = [], [], [], [], [], []
+    for n, p in model.named_parameters():
+        # Non-Backbone's learnable parameters
+        if "backbone" not in n and p.requires_grad:
+            if "bias" == n.split(".")[-1]:
+                param_dicts[0].append(p)      # no weight decay for all layers' bias
+            else:
+                if "norm" == n.split(".")[-2]:
+                    param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
+                else:
+                    param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight
+        # Backbone's learnable parameters
+        elif "backbone" in n and p.requires_grad:
+            if "bias" == n.split(".")[-1]:
+                param_dicts[3].append(p)      # no weight decay for all layers' bias
+            else:
+                if "norm" == n.split(".")[-2]:
+                    param_dicts[4].append(p)  # no weight decay for all NormLayers' weight
+                else:
+                    param_dicts[5].append(p)  # weight decay for all Non-NormLayers' weight
+
+    # Non-Backbone's learnable parameters
+    optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg['lr0'], weight_decay=0.0)
+    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg['weight_decay']})
+
+    # Backbone's learnable parameters
+    optimizer.add_param_group({"params": param_dicts[3], "lr": cfg['lr0'] * cfg['backbone_lr_ratio'], "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[4], "lr": cfg['lr0'] * cfg['backbone_lr_ratio'], "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[5], "lr": cfg['lr0'] * cfg['backbone_lr_ratio'], "weight_decay": cfg['weight_decay']})
 
     start_epoch = 0
     if resume and resume != 'None':
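
To restate the grouping rule implemented above in a standalone form: biases and normalization weights are exempted from weight decay, and backbone groups would additionally get lr0 * backbone_lr_ratio. The sketch below is a simplified, illustrative variant (it detects bias/norm parameters by dimensionality rather than by name) and is not the repo's code:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 8, 3))
decay, no_decay = [], []
for n, p in model.named_parameters():
    if p.ndim <= 1:                  # biases and norm weights/biases are 1-D
        no_decay.append(p)
    else:
        decay.append(p)              # conv/linear weights keep weight decay

optimizer = torch.optim.AdamW(
    [{"params": decay,    "weight_decay": 1e-4},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-4)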