yjh0410 2 years ago
parent
commit
ebddefb2f0

+ 3 - 3
config/model_config/rtcdet_v2_config.py

@@ -15,7 +15,7 @@ rtcdet_v2_cfg = {
         'stride': [8, 16, 32],  # P3, P4, P5
         'max_stride': 32,
         ## Neck: SPP
-        'neck': 'csp_sppf',
+        'neck': 'sppf',
         'neck_expand_ratio': 0.5,
         'pooling_size': 5,
         'neck_act': 'silu',
@@ -26,7 +26,7 @@ rtcdet_v2_cfg = {
         'fpn_reduce_layer': 'conv',
         'fpn_downsample_layer': 'conv',
         'fpn_core_block': 'elan_block',
-        'fpn_squeeze_ratio': 0.25,
+        'fpn_expand_ratio': 0.25,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
@@ -41,7 +41,7 @@ rtcdet_v2_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
-        'trans_type': 'rtcdet_v1_large',
+        'trans_type': 'yolox_large',
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': {'tal': {'topk': 10,

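For reference, a condensed view of the rtcdet_v2_cfg entries touched by this commit (old values noted in comments; every other key is unchanged):

rtcdet_v2_cfg = {
    # ... (unchanged keys omitted)
    ## Neck: SPP
    'neck': 'sppf',                  # was 'csp_sppf'
    # ...
    'fpn_core_block': 'elan_block',
    'fpn_expand_ratio': 0.25,        # renamed from 'fpn_squeeze_ratio'
    # ...
    ## Input
    'trans_type': 'yolox_large',     # was 'rtcdet_v1_large'
    # ...
}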
+ 8 - 0
engine.py

@@ -938,6 +938,14 @@ class RTMTrainer(object):
             print(' - Close < perspective of rotation > ...')
             self.trans_cfg['perspective'] = 0.0
 
+        # close random affine
+        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
+            print(' - Close < translate of affine > ...')
+            self.trans_cfg['translate'] = 0.0
+        if 'scale' in self.trans_cfg.keys():
+            print(' - Close < scale of affine >...')
+            self.trans_cfg['scale'] = [1.0, 1.0]
+
         # build a new transform for second stage
         print(' - Rebuild transforms ...')
         self.train_transform, self.trans_cfg = build_transform(

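The engine.py addition extends the late-training "close strong augmentation" step in RTMTrainer: besides the random perspective handled just above the hunk, random translation is now zeroed and the affine scale range is pinned to [1.0, 1.0] before the transform is rebuilt. A minimal standalone sketch of that mutation (the helper name and dict-only form are illustrative, not the trainer's actual API):

def close_random_affine(trans_cfg: dict) -> dict:
    # close random perspective (handled by the existing code just above the hunk)
    if trans_cfg.get('perspective', 0.0) > 0.0:
        print(' - Close < perspective of rotation > ...')
        trans_cfg['perspective'] = 0.0
    # close random affine: no translation, fixed scale
    if trans_cfg.get('translate', 0.0) > 0.0:
        print(' - Close < translate of affine > ...')
        trans_cfg['translate'] = 0.0
    if 'scale' in trans_cfg:
        print(' - Close < scale of affine > ...')
        trans_cfg['scale'] = [1.0, 1.0]
    return trans_cfg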
+ 11 - 11
models/detectors/rtcdet_v2/rtcdet_v2_backbone.py

@@ -25,10 +25,10 @@ class ELANNetv2(nn.Module):
         ## scale factor
         self.width = width
         self.depth = depth
-        self.squeeze_ratio = [0.5, 0.5, 0.375, 0.25]
+        self.expand_ratio = [0.5, 0.5, 0.5, 0.25]
         ## pyramid feats
-        self.feat_dims = [round(dim * width) for dim in [64, 128, 256, 512, 1024]]
-        self.branch_depths = [round(dep * depth) for dep in [3, 6, 6, 3]]
+        self.feat_dims = [round(dim * width) for dim in [64, 128, 256, 512, 1024, 1024]]
+        self.branch_depths = [round(dep * depth) for dep in [3, 3, 3, 3]]
         ## nonlinear
         self.act_type = act_type
         self.norm_type = norm_type
@@ -42,23 +42,23 @@ class ELANNetv2(nn.Module):
         )
         ## P2/4
         self.layer_2 = nn.Sequential(   
-            DSBlock(self.feat_dims[0], self.feat_dims[1], self.act_type, self.norm_type, self.depthwise),
-            ELAN_Stage(self.feat_dims[1], self.feat_dims[1], self.squeeze_ratio[0], self.branch_depths[0], True, self.act_type, self.norm_type, self.depthwise)
+            DSBlock(self.feat_dims[0], self.feat_dims[1], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELAN_Stage(self.feat_dims[1], self.feat_dims[2], self.expand_ratio[0], self.branch_depths[0], True, self.act_type, self.norm_type, self.depthwise)
         )
         ## P3/8
         self.layer_3 = nn.Sequential(
-            DSBlock(self.feat_dims[1], self.feat_dims[2], self.act_type, self.norm_type, self.depthwise),
-            ELAN_Stage(self.feat_dims[2], self.feat_dims[2], self.squeeze_ratio[1], self.branch_depths[1], True, self.act_type, self.norm_type, self.depthwise)
+            Conv(self.feat_dims[2], self.feat_dims[2], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELAN_Stage(self.feat_dims[2], self.feat_dims[3], self.expand_ratio[1], self.branch_depths[1], True, self.act_type, self.norm_type, self.depthwise)
         )
         ## P4/16
         self.layer_4 = nn.Sequential(
-            DSBlock(self.feat_dims[2], self.feat_dims[3], self.act_type, self.norm_type, self.depthwise),
-            ELAN_Stage(self.feat_dims[3], self.feat_dims[3], self.squeeze_ratio[2], self.branch_depths[2], True, self.act_type, self.norm_type, self.depthwise)
+            Conv(self.feat_dims[3], self.feat_dims[3], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELAN_Stage(self.feat_dims[3], self.feat_dims[4], self.expand_ratio[2], self.branch_depths[2], True, self.act_type, self.norm_type, self.depthwise)
         )
         ## P5/32
         self.layer_5 = nn.Sequential(
-            DSBlock(self.feat_dims[3], self.feat_dims[4], self.act_type, self.norm_type, self.depthwise),
-            ELAN_Stage(self.feat_dims[4], self.feat_dims[4], self.squeeze_ratio[3], self.branch_depths[3], True, self.act_type, self.norm_type, self.depthwise)
+            Conv(self.feat_dims[4], self.feat_dims[4], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELAN_Stage(self.feat_dims[4], self.feat_dims[5], self.expand_ratio[3], self.branch_depths[3], True, self.act_type, self.norm_type, self.depthwise)
         )
 
 
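At width=1.0 / depth=1.0, the new ELANNetv2 layout works out as below: channel expansion now happens inside each ELAN_Stage, while the downsampling step (DSBlock at P2, a stride-2 Conv at P3-P5) keeps its input width. This is only a quick arithmetic check of the hunk, not code from the repo:

width, depth = 1.0, 1.0
feat_dims = [round(d * width) for d in [64, 128, 256, 512, 1024, 1024]]  # [64, 128, 256, 512, 1024, 1024]
branch_depths = [round(d * depth) for d in [3, 3, 3, 3]]                 # [3, 3, 3, 3]

stages = [
    ('P2/4',  f'DSBlock {feat_dims[0]}->{feat_dims[1]}', f'ELAN_Stage {feat_dims[1]}->{feat_dims[2]}'),
    ('P3/8',  f'Conv s2 {feat_dims[2]}->{feat_dims[2]}', f'ELAN_Stage {feat_dims[2]}->{feat_dims[3]}'),
    ('P4/16', f'Conv s2 {feat_dims[3]}->{feat_dims[3]}', f'ELAN_Stage {feat_dims[3]}->{feat_dims[4]}'),
    ('P5/32', f'Conv s2 {feat_dims[4]}->{feat_dims[4]}', f'ELAN_Stage {feat_dims[4]}->{feat_dims[5]}'),
]
for name, down, elan in stages:
    print(f'{name:6s} | {down:20s} | {elan}')
# P2/4   | DSBlock 64->128      | ELAN_Stage 128->256
# P3/8   | Conv s2 256->256     | ELAN_Stage 256->512
# P4/16  | Conv s2 512->512     | ELAN_Stage 512->1024
# P5/32  | Conv s2 1024->1024   | ELAN_Stage 1024->1024

The resulting C3/C4/C5 widths (512/1024/1024 at width=1.0) match the fpn_dims used in the rtcdet_v2_pafpn.py test further down.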

+ 20 - 23
models/detectors/rtcdet_v2/rtcdet_v2_basic.py

@@ -1,3 +1,4 @@
+from typing import List
 import numpy as np
 import torch
 import torch.nn as nn
@@ -156,11 +157,12 @@ class YoloBottleneck(nn.Module):
     def __init__(self,
                  in_dim,
                  out_dim,
-                 expand_ratio=0.5,
-                 shortcut=False,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False):
+                 kernel_sizes :List[int] = [3, 3],
+                 expand_ratio :float     = 0.5,
+                 shortcut     :bool      = False,
+                 act_type     :str       = 'silu',
+                 norm_type    :str       = 'BN',
+                 depthwise    :bool      = False):
         super(YoloBottleneck, self).__init__()
         # ------------------ Basic parameters ------------------
         self.in_dim = in_dim
@@ -168,8 +170,8 @@ class YoloBottleneck(nn.Module):
         self.inter_dim = int(out_dim * expand_ratio)
         self.shortcut = shortcut and in_dim == out_dim
         # ------------------ Network parameters ------------------
-        self.cv1 = Conv(in_dim, self.inter_dim, k=1, norm_type=norm_type, act_type=act_type)
-        self.cv2 = Conv(self.inter_dim, out_dim, k=3, p=1, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.cv1 = Conv(in_dim, self.inter_dim, k=kernel_sizes[0], p=kernel_sizes[0]//2, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.cv2 = Conv(self.inter_dim, out_dim, k=kernel_sizes[1], p=kernel_sizes[1]//2, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
 
     def forward(self, x):
         h = self.cv2(self.cv1(x))
@@ -180,23 +182,23 @@ class YoloBottleneck(nn.Module):
 # ---------------------------- Base Modules ----------------------------
 ## ELAN Stage of Backbone
 class ELAN_Stage(nn.Module):
-    def __init__(self, in_dim, out_dim, squeeze_ratio :float=0.5, branch_depth :int=1, shortcut=False, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, in_dim, out_dim, expand_ratio :float=0.5, branch_depth :int=1, shortcut=False, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
         # ----------- Basic Parameters -----------
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.inter_dim = round(in_dim * squeeze_ratio)
-        self.squeeze_ratio = squeeze_ratio
+        self.inter_dim = round(in_dim * expand_ratio)
+        self.expand_ratio = expand_ratio
         self.branch_depth = branch_depth
         # ----------- Network Parameters -----------
         self.cv1 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.cv2 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.cv3 = nn.Sequential(*[
-            YoloBottleneck(self.inter_dim, self.inter_dim, 1.0, shortcut, act_type, norm_type, depthwise)
+            YoloBottleneck(self.inter_dim, self.inter_dim, [1, 3], 1.0, shortcut, act_type, norm_type, depthwise)
             for _ in range(branch_depth)
         ])
         self.cv4 = nn.Sequential(*[
-            YoloBottleneck(self.inter_dim, self.inter_dim, 1.0, shortcut, act_type, norm_type, depthwise)
+            YoloBottleneck(self.inter_dim, self.inter_dim, [1, 3], 1.0, shortcut, act_type, norm_type, depthwise)
             for _ in range(branch_depth)
         ])
         ## output
@@ -217,17 +219,12 @@ class DSBlock(nn.Module):
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.inter_dim = out_dim // 2
         # branch-1
-        self.maxpool = nn.Sequential(
-            Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type),
-            nn.MaxPool2d((2, 2), 2)
-        )
+        self.maxpool = nn.MaxPool2d((2, 2), 2)
         # branch-2
-        self.ds_conv = nn.Sequential(
-            Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type),
-            Conv(self.inter_dim, self.inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        ) 
+        self.ds_conv = Conv(in_dim, in_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        # output
+        self.out_conv = Conv(in_dim*2, out_dim, k=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
 
 
     def forward(self, x):
@@ -236,7 +233,7 @@ class DSBlock(nn.Module):
         # branch-2
         x2 = self.ds_conv(x)
         # out-proj
-        out = torch.cat([x1, x2], dim=1)
+        out = self.out_conv(torch.cat([x1, x2], dim=1))
 
         return out
 
@@ -247,7 +244,7 @@ def build_fpn_block(cfg, in_dim, out_dim):
     if cfg['fpn_core_block'] == 'elan_block':
         layer = ELAN_Stage(in_dim        = in_dim,
                            out_dim       = out_dim,
-                           squeeze_ratio = cfg['fpn_squeeze_ratio'],
+                           expand_ratio  = cfg['fpn_expand_ratio'],
                            branch_depth  = round(3 * cfg['depth']),
                            shortcut      = False,
                            act_type      = cfg['fpn_act'],

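The DSBlock rewrite is the most structural change in this file: the 1x1 squeeze convolutions are dropped from both branches, the max-pool and stride-2 conv now run at full input width, and a new 1x1 output projection fuses the concatenated result down to out_dim. A self-contained sketch of the new shape flow, with plain Conv2d+BN+SiLU standing in for the repo's Conv wrapper so it runs on its own:

import torch
import torch.nn as nn

class DSBlockSketch(nn.Module):
    """Simplified stand-in for the revised DSBlock (illustrative, not the repo class)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # branch-1: parameter-free 2x2 max-pool downsampling
        self.maxpool = nn.MaxPool2d((2, 2), 2)
        # branch-2: 3x3 stride-2 conv, keeping in_dim channels
        self.ds_conv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(in_dim), nn.SiLU(inplace=True))
        # output projection: fuse the concatenated 2*in_dim channels down to out_dim
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_dim * 2, out_dim, 1, bias=False),
            nn.BatchNorm2d(out_dim), nn.SiLU(inplace=True))

    def forward(self, x):
        x1 = self.maxpool(x)   # branch-1
        x2 = self.ds_conv(x)   # branch-2
        return self.out_conv(torch.cat([x1, x2], dim=1))

x = torch.randn(1, 64, 160, 160)
print(DSBlockSketch(64, 128)(x).shape)   # torch.Size([1, 128, 80, 80])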
+ 35 - 2
models/detectors/rtcdet_v2/rtcdet_v2_neck.py

@@ -1,7 +1,10 @@
 import torch
 import torch.nn as nn
 
-from .rtcdet_v2_basic import Conv
+try:
+    from .rtcdet_v2_basic import Conv
+except:
+    from rtcdet_v2_basic import Conv
 
 
 # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
@@ -60,4 +63,34 @@ def build_neck(cfg, in_dim, out_dim):
         neck = SPPFBlockCSP(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
 
     return neck
-        
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        ## Neck: SPP
+        'neck': 'csp_sppf',
+        'neck_expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+    }
+    in_dim = 2048
+    out_dim = 2048
+    # Head-1
+    model = build_neck(cfg, in_dim, out_dim)
+    feat = torch.randn(1, in_dim, 20, 20)
+    t0 = time.time()
+    outputs = model(feat)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(feat, ), verbose=False)
+    print('==============================')
+    print('FPN: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('FPN: Params : {:.2f} M'.format(params / 1e6))

+ 9 - 4
models/detectors/rtcdet_v2/rtcdet_v2_pafpn.py

@@ -14,8 +14,12 @@ class RTCDetPaFPN(nn.Module):
         super(RTCDetPaFPN, self).__init__()
         # --------------------------- Basic Parameters ---------------------------
         self.in_dims = in_dims
-        self.fpn_dims = in_dims
-        
+        self.fpn_dims = [round(256*cfg['width']), round(512*cfg['width']), round(1024*cfg['width'])]
+
+        # --------------------------- Input proj ---------------------------
+        self.input_projs = nn.ModuleList([nn.Conv2d(in_dim, fpn_dim, kernel_size=1)
+                                          for in_dim, fpn_dim in zip(in_dims, self.fpn_dims)])
+                
         # --------------------------- Top-down FPN ---------------------------
         ## P5 -> P4
         self.reduce_layer_1 = build_reduce_layer(cfg, self.fpn_dims[2], self.fpn_dims[2]//2)
@@ -46,6 +50,7 @@ class RTCDetPaFPN(nn.Module):
 
 
     def forward(self, fpn_feats):
+        fpn_feats = [layer(feat) for feat, layer in zip(fpn_feats, self.input_projs)]
         c3, c4, c5 = fpn_feats
 
         # Top down
@@ -101,12 +106,12 @@ if __name__ == '__main__':
         'fpn_reduce_layer': 'conv',
         'fpn_downsample_layer': 'conv',
         'fpn_core_block': 'elan_block',
-        'fpn_squeeze_ratio': 0.25,
+        'fpn_expand_ratio': 0.25,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
     }
-    fpn_dims = [256, 512, 1024]
+    fpn_dims = [512, 1024, 1024]
     out_dim = 256
     # Head-1
     model = build_fpn(cfg, fpn_dims, out_dim)
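The PaFPN change decouples the FPN widths from the backbone: fpn_dims is now fixed to [round(256*width), round(512*width), round(1024*width)], and a per-level 1x1 Conv2d projects the incoming features onto those widths before the top-down pass, which is why the test's fpn_dims become [512, 1024, 1024]. A small standalone sketch of just the projection step (the spatial sizes 80/40/20 are illustrative):

import torch
import torch.nn as nn

width = 1.0
in_dims = [512, 1024, 1024]   # backbone C3/C4/C5 widths from the new ELANNetv2 above
fpn_dims = [round(256 * width), round(512 * width), round(1024 * width)]

# one 1x1 projection per pyramid level
input_projs = nn.ModuleList([nn.Conv2d(c_in, c_out, kernel_size=1)
                             for c_in, c_out in zip(in_dims, fpn_dims)])

feats = [torch.randn(1, c, s, s) for c, s in zip(in_dims, [80, 40, 20])]
projected = [proj(f) for f, proj in zip(feats, input_projs)]
for p in projected:
    print(p.shape)   # (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)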