فهرست منبع

modify RTCDet-v2

yjh0410 2 سال پیش
والد
کامیت
b83f8dd9a9

+ 2 - 0
config/model_config/rtcdet_v2_config.py

@@ -26,6 +26,7 @@ rtcdet_v2_cfg = {
         'fpn_reduce_layer': 'conv',
         'fpn_downsample_layer': 'conv',
         'fpn_core_block': 'elan_block',
+        'fpn_branch_depth': 3,
         'fpn_expand_ratio': 0.25,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',
@@ -82,6 +83,7 @@ rtcdet_v2_cfg = {
         'fpn_reduce_layer': 'conv',
         'fpn_downsample_layer': 'conv',
         'fpn_core_block': 'elan_block',
+        'fpn_branch_depth': 3,
         'fpn_expand_ratio': 0.25,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',

+ 3 - 3
models/detectors/rtcdet_v1/rtcdet_v1_backbone.py

@@ -134,9 +134,9 @@ if __name__ == '__main__':
         'pretrained': True,
         'bk_act': 'silu',
         'bk_norm': 'BN',
-        'bk_dpw': True,
-        'width': 0.25,
-        'depth': 0.34,
+        'bk_dpw': False,
+        'width': 1.0,
+        'depth': 1.0,
     }
     model, feats = build_backbone(cfg)
     x = torch.randn(1, 3, 640, 640)

+ 11 - 11
models/detectors/rtcdet_v2/rtcdet_v2_backbone.py

@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 try:
-    from .rtcdet_v2_basic import Conv, ELAN_Stage, DSBlock
+    from .rtcdet_v2_basic import Conv, ELANBlock, DSBlock
 except:
-    from rtcdet_v2_basic import Conv, ELAN_Stage, DSBlock
+    from rtcdet_v2_basic import Conv, ELANBlock, DSBlock
 
 
 model_urls = {
@@ -26,9 +26,9 @@ class ELANNetv2(nn.Module):
         self.width = width
         self.depth = depth
         self.expand_ratio = [0.5, 0.5, 0.5, 0.25]
+        self.branch_depths = [round(dep * depth) for dep in [3, 3, 3, 3]]
         ## pyramid feats
         self.feat_dims = [round(dim * width) for dim in [64, 128, 256, 512, 1024, 1024]]
-        self.branch_depths = [round(dep * depth) for dep in [3, 3, 3, 3]]
         ## nonlinear
         self.act_type = act_type
         self.norm_type = norm_type
@@ -42,23 +42,23 @@ class ELANNetv2(nn.Module):
         )
         ## P2/4
         self.layer_2 = nn.Sequential(   
-            DSBlock(self.feat_dims[0], self.feat_dims[1], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
-            ELAN_Stage(self.feat_dims[1], self.feat_dims[2], self.expand_ratio[0], self.branch_depths[0], True, self.act_type, self.norm_type, self.depthwise)
+            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELANBlock(self.feat_dims[1], self.feat_dims[2], self.expand_ratio[0], self.branch_depths[0], True, self.act_type, self.norm_type, self.depthwise)
         )
         ## P3/8
         self.layer_3 = nn.Sequential(
-            Conv(self.feat_dims[2], self.feat_dims[2], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
-            ELAN_Stage(self.feat_dims[2], self.feat_dims[3], self.expand_ratio[1], self.branch_depths[1], True, self.act_type, self.norm_type, self.depthwise)
+            DSBlock(self.feat_dims[2], self.feat_dims[2], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.expand_ratio[1], self.branch_depths[1], True, self.act_type, self.norm_type, self.depthwise)
         )
         ## P4/16
         self.layer_4 = nn.Sequential(
-            Conv(self.feat_dims[3], self.feat_dims[3], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
-            ELAN_Stage(self.feat_dims[3], self.feat_dims[4], self.expand_ratio[2], self.branch_depths[2], True, self.act_type, self.norm_type, self.depthwise)
+            DSBlock(self.feat_dims[3], self.feat_dims[3], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.expand_ratio[2], self.branch_depths[2], True, self.act_type, self.norm_type, self.depthwise)
         )
         ## P5/32
         self.layer_5 = nn.Sequential(
-            Conv(self.feat_dims[4], self.feat_dims[4], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
-            ELAN_Stage(self.feat_dims[4], self.feat_dims[5], self.expand_ratio[3], self.branch_depths[3], True, self.act_type, self.norm_type, self.depthwise)
+            DSBlock(self.feat_dims[4], self.feat_dims[4], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
+            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.expand_ratio[3], self.branch_depths[3], True, self.act_type, self.norm_type, self.depthwise)
         )
 
 

+ 32 - 33
models/detectors/rtcdet_v2/rtcdet_v2_basic.py

@@ -180,8 +180,8 @@ class YoloBottleneck(nn.Module):
 
 
 # ---------------------------- Base Modules ----------------------------
-## ELAN Stage of Backbone
-class ELAN_Stage(nn.Module):
+## ELAN Block
+class ELANBlock(nn.Module):
     def __init__(self, in_dim, out_dim, expand_ratio :float=0.5, branch_depth :int=1, shortcut=False, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
         # ----------- Basic Parameters -----------
@@ -190,26 +190,28 @@ class ELAN_Stage(nn.Module):
         self.inter_dim = round(in_dim * expand_ratio)
         self.expand_ratio = expand_ratio
         self.branch_depth = branch_depth
+        self.shortcut = shortcut
         # ----------- Network Parameters -----------
         self.cv1 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.cv2 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.cv3 = nn.Sequential(*[
-            YoloBottleneck(self.inter_dim, self.inter_dim, [1, 3], 1.0, shortcut, act_type, norm_type, depthwise)
+            Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
             for _ in range(branch_depth)
         ])
         self.cv4 = nn.Sequential(*[
-            YoloBottleneck(self.inter_dim, self.inter_dim, [1, 3], 1.0, shortcut, act_type, norm_type, depthwise)
+            Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
             for _ in range(branch_depth)
         ])
-        ## output
-        self.out_conv = Conv(self.inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.out = Conv(self.inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
 
     def forward(self, x):
         x1 = self.cv1(x)
         x2 = self.cv2(x)
-        x3 = self.cv3(x2)
-        x4 = self.cv4(x3)
-        out = self.out_conv(torch.cat([x1, x2, x3, x4], dim=1))
+        x3 = self.cv3(x2) + x2 if self.shortcut else self.cv3(x2)
+        x4 = self.cv4(x3) + x3 if self.shortcut else self.cv4(x3)
+
+        # [B, C, H, W] -> [B, 2C, H, W]
+        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
 
         return out
     
@@ -217,23 +219,20 @@ class ELAN_Stage(nn.Module):
 class DSBlock(nn.Module):
     def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
-        self.in_dim = in_dim
-        self.out_dim = out_dim
-        # branch-1
-        self.maxpool = nn.MaxPool2d((2, 2), 2)
-        # branch-2
-        self.ds_conv = Conv(in_dim, in_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        # output
-        self.out_conv = Conv(in_dim*2, out_dim, k=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-
+        inter_dim = out_dim // 2
+        self.branch_1 = nn.Sequential(
+            nn.MaxPool2d((2, 2), 2),
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        )
+        self.branch_2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
 
     def forward(self, x):
-        # branch-1
-        x1 = self.maxpool(x)
-        # branch-2
-        x2 = self.ds_conv(x)
-        # out-proj
-        out = self.out_conv(torch.cat([x1, x2], dim=1))
+        x1 = self.branch_1(x)
+        x2 = self.branch_2(x)
+        out = torch.cat([x1, x2], dim=1)
 
         return out
 
@@ -242,15 +241,15 @@ class DSBlock(nn.Module):
 ## build fpn's core block
 def build_fpn_block(cfg, in_dim, out_dim):
     if cfg['fpn_core_block'] == 'elan_block':
-        layer = ELAN_Stage(in_dim        = in_dim,
-                           out_dim       = out_dim,
-                           expand_ratio  = cfg['fpn_expand_ratio'],
-                           branch_depth  = round(3 * cfg['depth']),
-                           shortcut      = False,
-                           act_type      = cfg['fpn_act'],
-                           norm_type     = cfg['fpn_norm'],
-                           depthwise     = cfg['fpn_depthwise']
-                           )
+        layer = ELANBlock(in_dim        = in_dim,
+                          out_dim       = out_dim,
+                          expand_ratio  = cfg['fpn_expand_ratio'],
+                          branch_depth  = round(3 * cfg['depth']),
+                          shortcut      = False,
+                          act_type      = cfg['fpn_act'],
+                          norm_type     = cfg['fpn_norm'],
+                          depthwise     = cfg['fpn_depthwise']
+                          )
         
     return layer
 

+ 1 - 0
models/detectors/rtcdet_v2/rtcdet_v2_pafpn.py

@@ -109,6 +109,7 @@ if __name__ == '__main__':
         'fpn_reduce_layer': 'conv',
         'fpn_downsample_layer': 'conv',
         'fpn_core_block': 'elan_block',
+        'fpn_branch_depth': 3,
         'fpn_expand_ratio': 0.25,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',

+ 1 - 0
models/detectors/yolov7/README.md

@@ -4,6 +4,7 @@
 |-------------|---------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
 | YOLOv7-Tiny | ELANNet-Tiny  | 1xb16 |  640  |         38.0           |       56.8        |   22.6            |   7.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
 | YOLOv7      | ELANNet-Large | 1xb16 |  640  |         48.0           |       67.5        |   144.6           |   44.0             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_large_coco.pth) |
+| YOLOv7-X    | ELANNet-Huge  | 8xb8  |  640  |                        |                   |                   |                    |  |
 
 - For training, we train YOLOv7 and YOLOv7-Tiny with 300 epochs on COCO.
 - For data augmentation, we use the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation, following the setting of [YOLOv5](https://github.com/ultralytics/yolov5).

+ 3 - 3
models/detectors/yolox/README.md

@@ -2,10 +2,10 @@
 
 |   Model |   Backbone   | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |---------|--------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOX-N | CSPDarkNet-N | 8xb8  |  640  |         30.4           |       48.9        |   7.5             |   2.3              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_n_coco.pth) |
 | YOLOX-S | CSPDarkNet-S | 8xb8  |  640  |         39.0           |       58.8        |   26.8            |   8.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_s_coco.pth) |
-| YOLOX-M | CSPDarkNet-M | 8xb8 |  640  |         46.2           |       66.0        |   74.3            |   25.4             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_m_coco.pth) |
-| YOLOX-L | CSPDarkNet-L | 8xb8 |  640  |         48.7           |       68.0        |   155.4           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_l_coco.pth) |
+| YOLOX-M | CSPDarkNet-M | 8xb8  |  640  |         46.2           |       66.0        |   74.3            |   25.4             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_m_coco.pth) |
+| YOLOX-L | CSPDarkNet-L | 8xb8  |  640  |         48.7           |       68.0        |   155.4           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_l_coco.pth) |
+| YOLOX-X | CSPDarkNet-X | 8xb8  |  640  |                        |                   |                   |                    |  |
 
 - For training, we train YOLOX series with 300 epochs on COCO.
 - For data augmentation, we use the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation.