yjh0410 hace 9 meses
padre
commit
a6adf35ca2
Se han modificado 2 ficheros con 21 adiciones y 15 borrados
  1. 18 14
      yolo/models/yolov8/yolov8_backbone.py
  2. 3 1
      yolo/models/yolov8/yolov8_pafpn.py

+ 18 - 14
yolo/models/yolov8/yolov8_backbone.py

@@ -123,27 +123,31 @@ class Yolov8Backbone(nn.Module):
 if __name__ == '__main__':
     import time
     from thop import profile
+
+    # YOLOv8 config
     class BaseConfig(object):
         def __init__(self) -> None:
-            self.use_pretrained = True
-            self.width = 0.25
+            self.use_pretrained = False
+            self.width = 0.50
             self.depth = 0.34
-            self.ratio = 2.0
-            self.model_scale = "n"
-
+            self.ratio = 2.00
+            self.model_scale = "s"
     cfg = BaseConfig()
+
+    # Build backbone
     model = Yolov8Backbone(cfg)
-    x = torch.randn(1, 3, 640, 640)
-    t0 = time.time()
+
+    # Randomly generate a input data
+    x = torch.randn(2, 3, 640, 640)
+
+    # Inference
     outputs = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
+    print(' - the shape of input :  ', x.shape)
     for out in outputs:
-        print(out.shape)
+        print(' - the shape of output : ', out.shape)
 
     x = torch.randn(1, 3, 640, 640)
-    print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    print('============== FLOPs & Params ================')
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))

+ 3 - 1
yolo/models/yolov8/yolov8_pafpn.py

@@ -15,7 +15,9 @@ class Yolov8PaFPN(nn.Module):
         super(Yolov8PaFPN, self).__init__()
         # --------------------------- Basic Parameters ---------------------------
         self.in_dims = in_dims[::-1]
-        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
+        self.out_dims = [round(256*cfg.width),
+                         round(512*cfg.width),
+                         round(512*cfg.width*cfg.ratio)]
 
         # ----------------------------- Yolov8's Top-down FPN -----------------------------
         ## P5 -> P4