yjh0410 1 anno fa
parent
commit
1c2b71e026
2 ha cambiato i file con 9 aggiunte e 2 eliminazioni
  1. 1 1
      models/__init__.py
  2. 8 1
      models/yolov8/yolov8.py

+ 1 - 1
models/__init__.py

@@ -29,7 +29,7 @@ def build_model(args, cfg, is_val=False):
     ## YOLOX
     elif 'yolox' in args.model:
         model, criterion = build_yolox(cfg, is_val)
-    ## YOLOv7
+    ## Modified Anchor-free YOLOv7
     elif 'yolov7' in args.model:
         model, criterion = build_yolov7(cfg, is_val)
     ## YOLOv8

+ 8 - 1
models/yolov8/yolov8.py

@@ -30,10 +30,17 @@ class Yolov8(nn.Module):
         self.no_multi_labels  = False if is_val else True
         
         # ---------------------- Network Parameters ----------------------
+        ## Backbone
         self.backbone = Yolov8Backbone(cfg)
-        self.neck     = SPPF(cfg, self.backbone.feat_dims[-1], self.backbone.feat_dims[-1])
+        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
+        ## Neck
+        self.neck     = SPPF(cfg, self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
+        self.pyramid_feat_dims[-1] = self.neck.out_dim
+        ## Neck: PaFPN
         self.fpn      = Yolov8PaFPN(cfg, self.backbone.feat_dims)
+        ## Head
         self.head     = Yolov8DetHead(cfg, self.fpn.out_dims)
+        ## Pred
         self.pred     = Yolov8DetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
 
     def post_process(self, cls_preds, box_preds):