فهرست منبع

modify yolov7

yjh0410 1 سال پیش
والد
کامیت
c8c1d7895f

+ 24 - 2
config/yolov7_config.py

@@ -4,6 +4,8 @@
 def build_yolov7_config(args):
     if args.model == 'yolov7_s':
         return Yolov7SConfig()
+    elif args.model == 'yolov7_l':
+        return Yolov7LConfig()
     else:
         raise NotImplementedError("No config for model: {}".format(args.model))
     
@@ -12,7 +14,6 @@ class Yolov7BaseConfig(object):
     def __init__(self) -> None:
         # ---------------- Model config ----------------
         self.width    = 1.0
-        self.depth    = 1.0
         self.reg_max  = 16
         self.out_stride = [8, 16, 32]
         self.max_stride = 32
@@ -33,6 +34,9 @@ class Yolov7BaseConfig(object):
         self.fpn_act  = 'silu'
         self.fpn_norm = 'BN'
         self.fpn_depthwise = False
+        self.fpn_expansions = [0.5, 0.5]
+        self.fpn_block_bw = 4
+        self.fpn_block_dw = 1
         ## Head
         self.head_act  = 'silu'
         self.head_norm = 'BN'
@@ -121,10 +125,28 @@ class Yolov7SConfig(Yolov7BaseConfig):
         super().__init__()
         # ---------------- Model config ----------------
         self.width = 0.50
-        self.depth = 0.34
         self.scale = "s"
+        self.fpn_expansions = [0.5, 1.0]
+        self.fpn_block_bw = 2
+        self.fpn_block_dw = 1
 
         # ---------------- Data process config ----------------
         self.mosaic_prob = 1.0
         self.mixup_prob  = 0.0
         self.copy_paste  = 0.0
+
+# YOLOv7-L
+class Yolov7LConfig(Yolov7BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.0
+        self.scale = "l"
+        self.fpn_expansions = [0.5, 0.5]
+        self.fpn_block_bw = 4
+        self.fpn_block_dw = 1
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.15
+        self.copy_paste  = 0.0

+ 0 - 1
config/yolov8_config.py

@@ -38,7 +38,6 @@ class Yolov8BaseConfig(object):
         self.head_act  = 'silu'
         self.head_norm = 'BN'
         self.head_depthwise = False
-        self.head_dim       = 256
         self.num_cls_head   = 2
         self.num_reg_head   = 2
 

+ 6 - 4
models/yolov7/yolov7_backbone.py

@@ -14,10 +14,12 @@ class Yolov7Backbone(nn.Module):
         # ---------------- Basic parameters ----------------
         self.model_scale = cfg.scale
         if self.model_scale in ["l", "x"]:
+            self.elan_depth = 2
             self.feat_dims = [round(64   * cfg.width), round(128  * cfg.width), round(256  * cfg.width),
                               round(512  * cfg.width), round(1024 * cfg.width), round(1024 * cfg.width)]
             self.last_stage_eratio = 0.25
         if self.model_scale in ["n", "s"]:
+            self.elan_depth = 1
             self.feat_dims = [round(64   * cfg.width), round(64  * cfg.width), round(128  * cfg.width),
                               round(256  * cfg.width), round(512 * cfg.width), round(1024 * cfg.width)]
             self.last_stage_eratio = 0.5
@@ -33,28 +35,28 @@ class Yolov7Backbone(nn.Module):
                       kernel_size=3, padding=1, stride=2,
                       act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),      
             ELANLayer(self.feat_dims[1], self.feat_dims[2],
-                      expansion=0.5, num_blocks=round(3*cfg.depth),
+                      expansion=0.5, num_blocks=self.elan_depth,
                       act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),      
         )
         self.layer_3 = nn.Sequential(
             MDown(self.feat_dims[2], self.feat_dims[2],
                   act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),             
             ELANLayer(self.feat_dims[2], self.feat_dims[3],
-                      expansion=0.5, num_blocks=round(3*cfg.depth),
+                      expansion=0.5, num_blocks=self.elan_depth,
                       act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),      
         )
         self.layer_4 = nn.Sequential(
             MDown(self.feat_dims[3], self.feat_dims[3],
                   act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),             
             ELANLayer(self.feat_dims[3], self.feat_dims[4],
-                      expansion=0.5, num_blocks=round(3*cfg.depth),
+                      expansion=0.5, num_blocks=self.elan_depth,
                       act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),      
         )
         self.layer_5 = nn.Sequential(
             MDown(self.feat_dims[4], self.feat_dims[4],
                   act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),             
             ELANLayer(self.feat_dims[4], self.feat_dims[5],
-                      expansion=self.last_stage_eratio, num_blocks=round(3*cfg.depth),
+                      expansion=self.last_stage_eratio, num_blocks=self.elan_depth,
                       act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),      
         )
 

+ 52 - 0
models/yolov7/yolov7_basic.py

@@ -136,3 +136,55 @@ class ELANLayer(nn.Module):
         out = self.conv_layer_3(torch.cat([x1, x2, x3, x4], dim=1))
 
         return out
+
+## PaFPN's ELAN-Block proposed by YOLOv7
+class ELANLayerFPN(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expansions   :List = [0.5, 0.5],
+                 branch_width :int  = 4,
+                 branch_depth :int  = 1,
+                 act_type     :str  = 'silu',
+                 norm_type    :str  = 'BN',
+                 depthwise=False):
+        super(ELANLayerFPN, self).__init__()
+        # Basic parameters
+        inter_dim  = round(in_dim * expansions[0])
+        inter_dim2 = round(inter_dim * expansions[1]) 
+        # Network structure
+        self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.cv3 = nn.ModuleList()
+        for idx in range(round(branch_width)):
+            if idx == 0:
+                cvs = [BasicConv(inter_dim, inter_dim2,
+                                 kernel_size=3, padding=1,
+                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            else:
+                cvs = [BasicConv(inter_dim2, inter_dim2,
+                                 kernel_size=3, padding=1,
+                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            # deeper
+            if round(branch_depth) > 1:
+                for _ in range(1, round(branch_depth)):
+                    cvs.append(BasicConv(inter_dim2, inter_dim2, kernel_size=3, padding=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+                self.cv3.append(nn.Sequential(*cvs))
+            else:
+                self.cv3.append(cvs[0])
+
+        self.output_proj = BasicConv(inter_dim*2+inter_dim2*len(self.cv3), out_dim,
+                                     kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        inter_outs = [x1, x2]
+        for m in self.cv3:
+            y1 = inter_outs[-1]
+            y2 = m(y1)
+            inter_outs.append(y2)
+        out = self.output_proj(torch.cat(inter_outs, dim=1))
+
+        return out

+ 1 - 1
models/yolov7/yolov7_head.py

@@ -95,7 +95,7 @@ class Yolov7DetHead(nn.Module):
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [DetHead(in_dim       = in_dims[level],
-                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 100)),
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
                      reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
                      num_cls_head = cfg.num_cls_head,
                      num_reg_head = cfg.num_reg_head,

+ 39 - 48
models/yolov7/yolov7_pafpn.py

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .yolov7_basic import BasicConv, ELANLayer, MDown
+from .yolov7_basic import BasicConv, ELANLayerFPN, MDown
 
 
 # PaFPN-ELAN (YOLOv7's)
@@ -12,6 +12,7 @@ class Yolov7PaFPN(nn.Module):
         super(Yolov7PaFPN, self).__init__()
         # ----------------------------- Basic parameters -----------------------------
         self.in_dims = in_dims
+        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(1024*cfg.width)]
         c3, c4, c5 = in_dims
 
         # ----------------------------- Top-down FPN -----------------------------
@@ -20,50 +21,54 @@ class Yolov7PaFPN(nn.Module):
                                         kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
         self.reduce_layer_2 = BasicConv(c4, round(256*cfg.width),
                                         kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-        self.top_down_layer_1 = ELANLayer(in_dim     = round(256*cfg.width) + round(256*cfg.width),
-                                          out_dim    = round(256*cfg.width),
-                                          expansion  = 0.5,
-                                          num_blocks = round(3*cfg.depth),
-                                          act_type   = cfg.fpn_act,
-                                          norm_type  = cfg.fpn_norm,
-                                          depthwise  = cfg.fpn_depthwise,
-                                          )
+        self.top_down_layer_1 = ELANLayerFPN(in_dim     = round(256*cfg.width) + round(256*cfg.width),
+                                             out_dim    = round(256*cfg.width),
+                                             expansions   = cfg.fpn_expansions,
+                                             branch_width = cfg.fpn_block_bw,
+                                             branch_depth = cfg.fpn_block_dw,
+                                             act_type   = cfg.fpn_act,
+                                             norm_type  = cfg.fpn_norm,
+                                             depthwise  = cfg.fpn_depthwise,
+                                             )
         ## P4 -> P3
         self.reduce_layer_3 = BasicConv(round(256*cfg.width), round(128*cfg.width),
                                         kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
         self.reduce_layer_4 = BasicConv(c3, round(128*cfg.width),
                                         kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-        self.top_down_layer_2 = ELANLayer(in_dim     = round(128*cfg.width) + round(128*cfg.width),
-                                          out_dim    = round(128*cfg.width),
-                                          expansion  = 0.5,
-                                          num_blocks = round(3*cfg.depth),
-                                          act_type   = cfg.fpn_act,
-                                          norm_type  = cfg.fpn_norm,
-                                          depthwise  = cfg.fpn_depthwise,
-                                          )
+        self.top_down_layer_2 = ELANLayerFPN(in_dim     = round(128*cfg.width) + round(128*cfg.width),
+                                             out_dim    = round(128*cfg.width),
+                                             expansions   = cfg.fpn_expansions,
+                                             branch_width = cfg.fpn_block_bw,
+                                             branch_depth = cfg.fpn_block_dw,
+                                             act_type   = cfg.fpn_act,
+                                             norm_type  = cfg.fpn_norm,
+                                             depthwise  = cfg.fpn_depthwise,
+                                             )
         # ----------------------------- Bottom-up FPN -----------------------------
         ## P3 -> P4
         self.downsample_layer_1 = MDown(round(128*cfg.width), round(256*cfg.width),
                                         act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-        self.bottom_up_layer_1 = ELANLayer(in_dim     = round(256*cfg.width) + round(256*cfg.width),
-                                           out_dim    = round(256*cfg.width),
-                                           expansion  = 0.5,
-                                           num_blocks = round(3*cfg.depth),
-                                           act_type   = cfg.fpn_act,
-                                           norm_type  = cfg.fpn_norm,
-                                           depthwise  = cfg.fpn_depthwise,
-                                           )
+        self.bottom_up_layer_1 = ELANLayerFPN(in_dim     = round(256*cfg.width) + round(256*cfg.width),
+                                              out_dim    = round(256*cfg.width),
+                                              expansions   = cfg.fpn_expansions,
+                                              branch_width = cfg.fpn_block_bw,
+                                              branch_depth = cfg.fpn_block_dw,
+                                              act_type     = cfg.fpn_act,
+                                              norm_type    = cfg.fpn_norm,
+                                              depthwise    = cfg.fpn_depthwise,
+                                              )
         ## P4 -> P5
         self.downsample_layer_2 = MDown(round(256*cfg.width), round(512*cfg.width),
                                         act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-        self.bottom_up_layer_2 = ELANLayer(in_dim     = round(512*cfg.width) + c5,
-                                           out_dim    = round(512*cfg.width),
-                                           expansion  = 0.5,
-                                           num_blocks = round(3*cfg.depth),
-                                           act_type   = cfg.fpn_act,
-                                           norm_type  = cfg.fpn_norm,
-                                           depthwise  = cfg.fpn_depthwise,
-                                           )
+        self.bottom_up_layer_2 = ELANLayerFPN(in_dim     = round(512*cfg.width) + c5,
+                                              out_dim    = round(512*cfg.width),
+                                              expansions   = cfg.fpn_expansions,
+                                              branch_width = cfg.fpn_block_bw,
+                                              branch_depth = cfg.fpn_block_dw,
+                                              act_type   = cfg.fpn_act,
+                                              norm_type  = cfg.fpn_norm,
+                                              depthwise  = cfg.fpn_depthwise,
+                                              )
 
         # ----------------------------- Head conv layers -----------------------------
         ## Head convs
@@ -77,15 +82,6 @@ class Yolov7PaFPN(nn.Module):
                                      kernel_size=3, padding=1, stride=1,
                                      act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
 
-        # ---------------------- Yolov5's output projection ----------------------
-        self.out_layers = nn.ModuleList([
-            BasicConv(in_dim, round(cfg.head_dim*cfg.width), kernel_size=1,
-                      act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-                      for in_dim in [round(256*cfg.width), round(512*cfg.width), round(1024*cfg.width)]
-                      ])
-        self.out_dims = [round(cfg.head_dim*cfg.width)] * 3
-
-
     def forward(self, features):
         c3, c4, c5 = features
 
@@ -112,10 +108,5 @@ class Yolov7PaFPN(nn.Module):
         p5 = self.bottom_up_layer_2(p5)
 
         out_feats = [self.head_conv_1(p3), self.head_conv_2(p4), self.head_conv_3(p5)]
-
-        # output proj layers
-        out_feats_proj = []
-        for feat, layer in zip(out_feats, self.out_layers):
-            out_feats_proj.append(layer(feat))
             
-        return out_feats_proj
+        return out_feats

+ 1 - 1
models/yolov8/yolov8_head.py

@@ -95,7 +95,7 @@ class Yolov8DetHead(nn.Module):
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [DetHead(in_dim       = in_dims[level],
-                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 100)),
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
                      reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
                      num_cls_head = cfg.num_cls_head,
                      num_reg_head = cfg.num_reg_head,

+ 3 - 12
models/yolov8/yolov8_pafpn.py

@@ -17,7 +17,7 @@ class Yolov8PaFPN(nn.Module):
         print('FPN: {}'.format("Yolo PaFPN"))
         # --------------------------- Basic Parameters ---------------------------
         self.in_dims = in_dims[::-1]
-        self.out_dims = [round(cfg.head_dim * cfg.width)] * 3
+        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
 
         # ---------------- Top dwon ----------------
         ## P5 -> P4
@@ -67,10 +67,6 @@ class Yolov8PaFPN(nn.Module):
                                            norm_type  = cfg.fpn_norm,
                                            depthwise  = cfg.fpn_depthwise,
                                            )
-        self.out_layers = nn.ModuleList([
-            BasicConv(feat_dim, self.out_dims[i], kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-            for i, feat_dim in enumerate([round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)])
-            ])
 
         self.init_weights()
         
@@ -104,10 +100,5 @@ class Yolov8PaFPN(nn.Module):
         p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
 
         out_feats = [p3, p4, p5] # [P3, P4, P5]
-        
-        # output proj layers
-        out_feats_proj = []
-        for feat, layer in zip(out_feats, self.out_layers):
-            out_feats_proj.append(layer(feat))
-        
-        return out_feats_proj
+                
+        return out_feats