yjh0410 2 years ago
parent
commit
4a1b20dacd

+ 11 - 7
models/detectors/yolov8/yolov8.py

@@ -29,16 +29,17 @@ class YOLOv8(nn.Module):
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
         self.device = device
-        self.stride = cfg['stride']
+        self.strides = cfg['stride']
         self.reg_max = cfg['reg_max']
         self.num_classes = num_classes
         self.trainable = trainable
         self.conf_thresh = conf_thresh
         self.nms_thresh = nms_thresh
+        self.num_levels = len(self.strides)
+        self.num_classes = num_classes
         self.topk = topk
         self.deploy = deploy
         self.nms_class_agnostic = nms_class_agnostic
-        self.head_dim = round(256*cfg['width'])
         
         # ---------------------- Network Parameters ----------------------
         ## ----------- Backbone -----------
@@ -53,13 +54,16 @@ class YOLOv8(nn.Module):
         self.fpn_dims = self.fpn.out_dim
 
         ## ----------- Heads -----------
-        self.det_heads = build_det_head(
-            cfg, self.fpn_dims, self.head_dim, 4 * self.reg_max, num_levels=len(self.stride))
+        self.det_heads = build_det_head(cfg, self.fpn_dims, self.num_levels, num_classes, self.reg_max)
 
         ## ----------- Preds -----------
-        self.pred_layers = build_pred_layer(
-            self.det_heads.cls_head_dim, self.det_heads.reg_head_dim, self.stride,
-            num_classes=num_classes, num_coords=4, num_levels=len(self.stride), reg_max=self.reg_max)
+        self.pred_layers = build_pred_layer(cls_dim     = self.det_heads.cls_head_dim,
+                                            reg_dim     = self.det_heads.reg_head_dim,
+                                            strides     = self.strides,
+                                            num_classes = num_classes,
+                                            num_coords  = 4,
+                                            num_levels  = self.num_levels,
+                                            reg_max     = self.reg_max)
 
     ## post-process
     def post_process(self, cls_preds, box_preds):

+ 14 - 14
models/detectors/yolov8/yolov8_basic.py

@@ -109,29 +109,29 @@ class Yolov8StageBlock(nn.Module):
     def __init__(self,
                  in_dim,
                  out_dim,
-                 expand_ratio = 0.5,
-                 num_blocks   = 1,
-                 shortcut     = False,
-                 act_type     = 'silu',
-                 norm_type    = 'BN',
-                 depthwise    = False,):
+                 num_blocks = 1,
+                 shortcut   = False,
+                 act_type   = 'silu',
+                 norm_type  = 'BN',
+                 depthwise  = False,):
         super(Yolov8StageBlock, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)
-        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.inter_dim = out_dim // 2
+        self.input_proj = Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.m = nn.Sequential(*(
-            Yolov8Bottleneck(inter_dim, inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
+            Yolov8Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
             for _ in range(num_blocks)))
-        self.cv3 = Conv((2 + num_blocks) * inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.output_proj = Conv((2 + num_blocks) * self.inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
 
     def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
+        # Input proj
+        x1, x2 = torch.split(self.input_proj(x), self.inter_dim, dim=1)
         out = list([x1, x2])
 
+        # Bottleneck
         out.extend(m(out[-1]) for m in self.m)
 
-        out = self.cv3(torch.cat(out, dim=1))
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
 
         return out
     

+ 6 - 6
models/detectors/yolov8/yolov8_head.py

@@ -72,14 +72,14 @@ class SingleLevelHead(nn.Module):
 
 # Multi-level Head
 class MultiLevelHead(nn.Module):
-    def __init__(self, cfg, in_dims, cls_out_dim, reg_out_dim, num_levels=3):
+    def __init__(self, cfg, in_dims, num_levels=3, num_classes=80, reg_max=16):
         super().__init__()
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [SingleLevelHead(
                 in_dims[level],
-                cls_out_dim,            # cls head dim
-                reg_out_dim,            # reg head dim
+                max(in_dims[0], num_classes),        # cls head out_dim
+                max(in_dims[0]//4, 16, 4*reg_max),   # reg head out_dim
                 cfg['num_cls_head'],
                 cfg['num_reg_head'],
                 cfg['head_act'],
@@ -111,9 +111,9 @@ class MultiLevelHead(nn.Module):
     
 
 # build detection head
-def build_det_head(cfg, in_dims, cls_out_dim, reg_out_dim, num_levels=3):
+def build_det_head(cfg, in_dims, num_levels=3, num_classes=80, reg_max=16):
     if cfg['head'] == 'decoupled_head':
-        head = MultiLevelHead(cfg, in_dims, cls_out_dim, reg_out_dim, num_levels) 
+        head = MultiLevelHead(cfg, in_dims, num_levels, num_classes, reg_max)
 
     return head
 
@@ -134,7 +134,7 @@ if __name__ == '__main__':
     cls_out_dim = 256
     reg_out_dim = 64
     # Head-1
-    model = build_det_head(cfg, fpn_dims, cls_out_dim, reg_out_dim, num_levels=3)
+    model = build_det_head(cfg, fpn_dims, num_levels=3, num_classes=80, reg_max=16)
     print(model)
     fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
     t0 = time.time()

+ 0 - 4
models/detectors/yolov8/yolov8_pafpn.py

@@ -31,7 +31,6 @@ class Yolov8PaFPN(nn.Module):
         ## P5 -> P4
         self.top_down_layer_1 = Yolov8StageBlock(in_dim       = c5 + c4,
                                                  out_dim      = round(512*width),
-                                                 expand_ratio = 0.5,
                                                  num_blocks   = round(3*depth),
                                                  shortcut     = False,
                                                  act_type     = act_type,
@@ -41,7 +40,6 @@ class Yolov8PaFPN(nn.Module):
         ## P4 -> P3
         self.top_down_layer_2 = Yolov8StageBlock(in_dim       = round(512*width) + c3,
                                                  out_dim      = round(256*width),
-                                                 expand_ratio = 0.5,
                                                  num_blocks   = round(3*depth),
                                                  shortcut     = False,
                                                  act_type     = act_type,
@@ -53,7 +51,6 @@ class Yolov8PaFPN(nn.Module):
         self.dowmsample_layer_1 = Conv(round(256*width), round(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.bottom_up_layer_1 = Yolov8StageBlock(in_dim       = round(256*width) + round(512*width),
                                                   out_dim      = round(512*width),
-                                                  expand_ratio = 0.5,
                                                   num_blocks   = round(3*depth),
                                                   shortcut     = False,
                                                   act_type     = act_type,
@@ -64,7 +61,6 @@ class Yolov8PaFPN(nn.Module):
         self.dowmsample_layer_2 = Conv(round(512*width), round(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.bottom_up_layer_2 = Yolov8StageBlock(in_dim       = round(512 * width) + c5,
                                                   out_dim      = round(512 * width * ratio),
-                                                  expand_ratio = 0.5,
                                                   num_blocks   = round(3*depth),
                                                   shortcut     = False,
                                                   act_type     = act_type,