Bladeren bron

update RTCDet

yjh0410 1 jaar geleden
bovenliggende
commit
1c33255e13

+ 0 - 1
config/model_config/rtcdet_config.py

@@ -10,7 +10,6 @@ rtcdet_cfg = {
         'bk_depthwise': False,
         'width': 0.50,
         'depth': 0.34,
-        'ratio': 2.0,
         'stride': [8, 16, 32],  # P3, P4, P5
         'max_stride': 32,
         'reg_max': 16,

+ 5 - 6
models/detectors/rtcdet/rtcdet_backbone.py

@@ -10,11 +10,11 @@ except:
 # ---------------------------- Basic functions ----------------------------
 ## YOLOv8's backbone
 class RTCBackbone(nn.Module):
-    def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
         super(RTCBackbone, self).__init__()
-        self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(512 * width * ratio)]
+        self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(1024 * width)]
         # P1/2
-        self.layer_1 = BasicConv(3, self.feat_dims[0], kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
+        self.layer_1 = BasicConv(3, self.feat_dims[0], kernel_size=6, padding=2, stride=2, act_type=act_type, norm_type=norm_type)
         # P2/4
         self.layer_2 = nn.Sequential(
             BasicConv(self.feat_dims[0], self.feat_dims[1],
@@ -35,7 +35,7 @@ class RTCBackbone(nn.Module):
                       act_type=act_type, norm_type=norm_type, depthwise=depthwise),
             RTCBlock(in_dim     = self.feat_dims[2],
                      out_dim    = self.feat_dims[2],
-                     num_blocks = round(6*depth),
+                     num_blocks = round(9*depth),
                      shortcut   = True,
                      act_type   = act_type,
                      norm_type  = norm_type,
@@ -48,7 +48,7 @@ class RTCBackbone(nn.Module):
                       act_type=act_type, norm_type=norm_type, depthwise=depthwise),
             RTCBlock(in_dim     = self.feat_dims[3],
                      out_dim    = self.feat_dims[3],
-                     num_blocks = round(6*depth),
+                     num_blocks = round(9*depth),
                      shortcut   = True,
                      act_type   = act_type,
                      norm_type  = norm_type,
@@ -96,7 +96,6 @@ def build_backbone(cfg):
     # model
     backbone = RTCBackbone(width=cfg['width'],
                            depth=cfg['depth'],
-                           ratio=cfg['ratio'],
                            act_type=cfg['bk_act'],
                            norm_type=cfg['bk_norm'],
                            depthwise=cfg['bk_depthwise']

+ 1 - 1
models/detectors/rtcdet/rtcdet_basic.py

@@ -109,7 +109,7 @@ class RTCBlock(nn.Module):
         self.inter_dim = out_dim // 2
         self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
         self.m = nn.ModuleList([
-            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
+            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [1, 3], shortcut, act_type, norm_type, depthwise)
             for _ in range(num_blocks)])
         self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
 

+ 2 - 2
models/detectors/rtcdet/rtcdet_head.py

@@ -136,9 +136,9 @@ if __name__ == '__main__':
         'head_depthwise': False,
         'reg_max': 16,
     }
-    fpn_dims = [256, 512, 512]
+    fpn_dims = [256, 256, 256]
     cls_out_dim = 256
-    reg_out_dim = 64
+    reg_out_dim = 256
     # Head-1
     model = build_head(cfg, fpn_dims, num_levels=3, num_classes=80, reg_max=16)
     print(model)

+ 51 - 43
models/detectors/rtcdet/rtcdet_pafpn.py

@@ -11,11 +11,10 @@ except:
 # PaFPN-ELAN
 class RTCPaFPN(nn.Module):
     def __init__(self, 
-                 in_dims   = [256, 512, 512],
+                 in_dims   = [256, 512, 1024],
                  out_dim   = 256,
                  width     = 1.0,
                  depth     = 1.0,
-                 ratio     = 1.0,
                  act_type  = 'silu',
                  norm_type = 'BN',
                  depthwise = False):
@@ -28,9 +27,12 @@ class RTCPaFPN(nn.Module):
         self.depth = depth
         c3, c4, c5 = in_dims
 
-        # ---------------- Top dwon FPN----------------
+        # ---------------- Top-down FPN ----------------
         ## P5 -> P4
-        self.top_down_layer_1 = RTCBlock(in_dim      = c5 + c4,
+        self.reduce_layer_1   = BasicConv(c5, round(512*width),
+                                          kernel_size=1, padding=0, stride=1,
+                                          act_type=act_type, norm_type=norm_type)
+        self.top_down_layer_1 = RTCBlock(in_dim      = round(512*width) + c4,
                                          out_dim     = round(512*width),
                                          num_blocks  = round(3*depth),
                                          shortcut    = False,
@@ -38,8 +40,12 @@ class RTCPaFPN(nn.Module):
                                          norm_type   = norm_type,
                                          depthwise   = depthwise,
                                          )
+
         ## P4 -> P3
-        self.top_down_layer_2 = RTCBlock(in_dim      = round(512*width) + c3,
+        self.reduce_layer_2   = BasicConv(round(512*width), round(256*width),
+                                          kernel_size=1, padding=0, stride=1,
+                                          act_type=act_type, norm_type=norm_type)
+        self.top_down_layer_2 = RTCBlock(in_dim      = round(256*width) + c3,
                                          out_dim     = round(256*width),
                                          num_blocks  = round(3*depth),
                                          shortcut    = False,
@@ -47,38 +53,39 @@ class RTCPaFPN(nn.Module):
                                          norm_type   = norm_type,
                                          depthwise   = depthwise,
                                          )
-        
-        # ---------------- Bottom up PAN----------------
+
+        # ---------------- Bottom-up PAN ----------------
         ## P3 -> P4
         self.dowmsample_layer_1 = BasicConv(round(256*width), round(256*width),
                                             kernel_size=3, padding=1, stride=2,
                                             act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.bottom_up_layer_1 = RTCBlock(in_dim      = round(256*width) + round(512*width),
-                                          out_dim     = round(512*width),
-                                          num_blocks  = round(3*depth),
-                                          shortcut    = False,
-                                          act_type    = act_type,
-                                          norm_type   = norm_type,
-                                          depthwise   = depthwise,
-                                          )
+        self.bottom_up_layer_1  = RTCBlock(in_dim      = round(256*width) + round(256*width),
+                                           out_dim     = round(512*width),
+                                           num_blocks  = round(3*depth),
+                                           shortcut    = False,
+                                           act_type    = act_type,
+                                           norm_type   = norm_type,
+                                           depthwise   = depthwise,
+                                           )
+
         ## P4 -> P5
         self.dowmsample_layer_2 = BasicConv(round(512*width), round(512*width),
                                             kernel_size=3, padding=1, stride=2,
                                             act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.bottom_up_layer_2 = RTCBlock(in_dim      = round(512 * width) + c5,
-                                          out_dim     = round(512 * width * ratio),
-                                          num_blocks  = round(3*depth),
-                                          shortcut    = False,
-                                          act_type    = act_type,
-                                          norm_type   = norm_type,
-                                          depthwise   = depthwise,
-                                          )
+        self.bottom_up_layer_2  = RTCBlock(in_dim      = round(512*width) + round(512*width),
+                                           out_dim     = round(1024*width),
+                                           num_blocks  = round(3*depth),
+                                           shortcut    = False,
+                                           act_type    = act_type,
+                                           norm_type   = norm_type,
+                                           depthwise   = depthwise,
+                                           )
 
         # ---------------- Output projection ----------------
         ## Output projs
         self.out_layers = nn.ModuleList([
             BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-            for in_dim in [round(256*width), round(512*width), round(512*width * ratio)]
+            for in_dim in [round(256*width), round(512*width), round(1024*width)]
             ])
         self.out_dims = [out_dim] * 3
 
@@ -97,26 +104,28 @@ class RTCPaFPN(nn.Module):
 
         # Top down
         ## P5 -> P4
-        c6 = F.interpolate(c5, scale_factor=2.0)
-        c7 = torch.cat([c6, c4], dim=1)
-        c8 = self.top_down_layer_1(c7)
+        c6 = self.reduce_layer_1(c5)
+        c7 = F.interpolate(c6, scale_factor=2.0)
+        c8 = torch.cat([c7, c4], dim=1)
+        c9 = self.top_down_layer_1(c8)
         ## P4 -> P3
-        c9 = F.interpolate(c8, scale_factor=2.0)
-        c10 = torch.cat([c9, c3], dim=1)
-        c11 = self.top_down_layer_2(c10)
+        c10 = self.reduce_layer_2(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)
+        c12 = torch.cat([c11, c3], dim=1)
+        c13 = self.top_down_layer_2(c12)
 
         # Bottom up
-        # p3 -> P4
-        c12 = self.dowmsample_layer_1(c11)
-        c13 = torch.cat([c12, c8], dim=1)
-        c14 = self.bottom_up_layer_1(c13)
-        # P4 -> P5
-        c15 = self.dowmsample_layer_2(c14)
-        c16 = torch.cat([c15, c5], dim=1)
-        c17 = self.bottom_up_layer_2(c16)
-
-        out_feats = [c11, c14, c17] # [P3, P4, P5]
+        ## P3 -> P4
+        c14 = self.dowmsample_layer_1(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)
+        ## P4 -> P5
+        c17 = self.dowmsample_layer_2(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)
 
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+        
         # output proj layers
         out_feats_proj = []
         for feat, layer in zip(out_feats, self.out_layers):
@@ -132,7 +141,6 @@ def build_fpn(cfg, in_dims, out_dim):
                            out_dim   = out_dim,
                            width     = cfg['width'],
                            depth     = cfg['depth'],
-                           ratio     = cfg['ratio'],
                            act_type  = cfg['fpn_act'],
                            norm_type = cfg['fpn_norm'],
                            depthwise = cfg['fpn_depthwise']
@@ -154,8 +162,8 @@ if __name__ == '__main__':
         'depth': 1.0,
         'ratio': 1.0,
     }
-    model = build_fpn(cfg, in_dims=[256, 512, 512], out_dim=256)
-    pyramid_feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 512, 20, 20)]
+    model = build_fpn(cfg, in_dims=[256, 512, 1024], out_dim=256)
+    pyramid_feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
     t0 = time.time()
     outputs = model(pyramid_feats)
     t1 = time.time()

+ 5 - 1
models/detectors/rtdetr/basic_modules/basic.py

@@ -214,7 +214,11 @@ class BasicConv(nn.Module):
         if not self.depthwise:
             return self.act(self.norm(self.conv(x)))
         else:
-            return self.act(self.norm2(self.conv2(self.norm1(self.conv1(x)))))
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.norm2(self.conv2(x))
+            return x
 
 
 # ----------------- CNN Modules -----------------

+ 36 - 1
models/detectors/yolox/yolox_head.py

@@ -1,7 +1,10 @@
 import torch
 import torch.nn as nn
 
-from .yolox_basic import Conv
+try:
+    from .yolox_basic import Conv
+except:
+    from  yolox_basic import Conv
 
 
 class DecoupledHead(nn.Module):
@@ -71,3 +74,35 @@ def build_head(cfg, in_dim, out_dim, num_classes=80):
     head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
 
     return head
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'head': 'decoupled_head',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'head_depthwise': False,
+    }
+    fpn_dims = [256, 256, 256]
+    cls_out_dim = 256
+    reg_out_dim = 256
+    # Head-1
+    model = build_head(cfg, 256, 256, num_classes=80)
+    print(model)
+    fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
+    t0 = time.time()
+    outputs = model(fpn_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(fpn_feats, ), verbose=False)
+    print('==============================')
+    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-1: Params : {:.2f} M'.format(params / 1e6))

+ 35 - 2
models/detectors/yolox/yolox_pafpn.py

@@ -2,7 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .yolox_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+try:
+    from .yolox_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+except:
+    from  yolox_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
 
 
 # YOLO-Style PaFPN
@@ -89,4 +92,34 @@ def build_fpn(cfg, in_dims, out_dim=None):
     if model == 'yolox_pafpn':
         fpn_net = YoloxPaFPN(cfg, in_dims, out_dim)
 
-    return fpn_net
+    return fpn_net
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'fpn': 'yolox_pafpn',
+        'fpn_reduce_layer': 'conv',
+        'fpn_downsample_layer': 'conv',
+        'fpn_core_block': 'cspblock',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        'width': 1.0,
+        'depth': 1.0,
+    }
+    model = build_fpn(cfg, in_dims=[256, 512, 1024], out_dim=256)
+    pyramid_feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+    t0 = time.time()
+    outputs = model(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))