Przeglądaj źródła

add yolo benchmark

yjh0410 1 rok temu
rodzic
commit
e0a199fcb0
2 zmienionych plików z 172 dodań i 0 usunięć
  1. 53 0
      README.md
  2. 119 0
      yolo/benchmark.py

+ 53 - 0
README.md

@@ -12,6 +12,59 @@
 ## 准备工作
 在使用此代码前,需要读者完成一些必要的环境配置,如python语言的安装、pytorch框架的安装等,随后,遵循`yolo/`和`odlab/`两个文件中的`README.md`文件所提供的内容,配置相关的环境、准备学习所需的数据集,并了解如何使用此项目代码进行训练和测试。如果读者想使用此代码去训练自定义的数据集,也请遵从这两个文件夹中的`README.md`文件中所给出的指示和说明来准备数据,并训练和测试。
 
+## 实验结果
+### YOLO系列
+下面的两个表分别汇报了本项目的YOLO系列的small量级的模型在VOC和COCO数据集上的性能指标,
+
+- VOC
+
+| Model       | Batch | Scale | AP<sup>val<br>0.5 | Weight |  Logs  |
+|-------------|-------|-------|-------------------|--------|--------|
+| YOLOv1-R18  | 1xb16 |  640  |               | [ckpt]() | [log]() |
+| YOLOv2-R18  | 1xb16 |  640  |               | [ckpt]() | [log]() |
+| YOLOv3-S    | 1xb16 |  640  |               | [ckpt]() | [log]() |
+| YOLOv5-S    | 1xb16 |  640  |               | [ckpt]() | [log]() |
+| YOLOv5-AF-S | 1xb16 |  640  |               | [ckpt]() | [log]() |
+| YOLOv8-S    | 1xb16 |  640  |               | [ckpt]() | [log]() |
+| GELAN-S     | 1xb16 |  640  |               | [ckpt]() | [log]() |
+
+- COCO
+
+| Model       | Batch | Scale | FPS<sup>FP32<br>RTX 4060 |AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight | Logs |
+|-------------|-------|-------|--------------------------|------------------------|-------------------|-------------------|--------------------|--------|------|
+| YOLOv1-R18  | 1xb16 |  640  |                          |         27.6           |       46.8        |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/yolov1_r18_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/YOLOv1-R18-COCO.txt) |
+| YOLOv2-R18  | 1xb16 |  640  |                          |         28.4           |       47.4        |   38.0            |   21.5             | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/yolov2_r18_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/YOLOv2-R18-COCO.txt) |
+| YOLOv3-S    | 1xb16 |  640  |                          |         31.3           |        49.2       |   25.2            |   7.3              | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v3/releases/download/yolo_tutorial_ckpt/yolov3_s_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v3/releases/download/yolo_tutorial_ckpt/YOLOv3-S-COCO.txt) |
+| YOLOv5-S    | 1xb16 |  640  |                          |       38.8             |     56.9          |   27.3            |   9.0              | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/yolov5_s_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/YOLOv5-S-COCO.txt) |
+| YOLOv5-AF-S | 1xb16 |  640  |                          |       39.6             |       58.7        |   26.9            |   8.9              | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/yolov5_af_s_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/YOLOv5-AF-S-COCO.txt) |
+| YOLOv8-S    | 1xb16 |  640  |                          |                        |                   |   28.4            |   11.3            |  |  |
+| GELAN-S     | 1xb16 |  640  |                          |                        |                   |   26.9            |   8.9             |  |  |
+
+### RT-DETR系列
+下表汇报了本项目的RT-DETR系列在COCO数据集上的性能指标,
+
+- COCO
+
+| Model        | Batch | Scale | FPS<sup>FP32<br>RTX 4060 |AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight | Logs |
+|--------------|-------|-------|--------------------------|------------------------|-------------------|-------------------|--------------------|--------|------|
+| RT-DETR-R18  | 4xb4  |  640  |                          |           45.5         |        63.5       |        66.8       |        21.0        | [ckpt](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/rtdetr_r18_coco.pth) | [log](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/RT-DETR-R18-COCO.txt)|
+| RT-DETR-R50  | 4xb4  |  640  |                          |           50.6         |        69.4       |       112.1       |        36.7        | [ckpt](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/rtdetr_r50_coco.pth) | [log](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/RT-DETR-R50-COCO.txt)|
+
+### ODLab系列
+下表汇报了本项目的ODLab系列在COCO数据集上的性能指标,
+
+- COCO
+
+| Model          | Sclae      | FPS<sup>FP32<br>RTX 4060 | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | Weight | Logs |
+|----------------|------------|--------------------------|------------------------|-------------------|--------|------|
+| FCOS_R18_1x    |  800,1333  |           24             |          34.0          |        52.2       | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_r18_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-R18-1x.txt) |
+| FCOS_R50_1x    |  800,1333  |            9             |          39.0          |        58.3       | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_r50_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-R50-1x.txt) |
+| FCOS_RT_R18_3x |  512,736   |           56             |          35.8          |        53.3       | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_rt_r18_3x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-RT-R18-3x.txt) |
+| FCOS_RT_R50_3x |  512,736   |           34             |          40.7          |        59.3       | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_rt_r50_3x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-RT-R50-3x.txt) |
+| YOLOF_R18_C5_1x  |  800,1333  |          54          |          32.8          |       51.4        | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/yolof_r18_c5_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/YOLOF-R18-C5-1x.txt) |
+| YOLOF_R50_C5_1x  |  800,1333  |          21          |          37.7          |       57.2        | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/yolof_r50_c5_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/YOLOF-R50-C5-1x.txt) |
+
+
 # The source code of the second edition of the book "YOLO Object Detection"
 This project is the source code of the "YOLO Target Detection" book (second edition), which includes all YOLO models, RT-DETR models, DETR models, FCOS models, and YOLOF models involved in this book. For YOLO and RT-DETR, readers can find all source codes in the `yolo/` folder of the project; for DETR, FCOS and YOLOF models, readers can find all source codes in the `odlab/` folder of the project. 
 

+ 119 - 0
yolo/benchmark.py

@@ -0,0 +1,119 @@
+import argparse
+import cv2
+import os
+import time
+import numpy as np
+from copy import deepcopy
+import torch
+
+# load transform
+from dataset.build import build_dataset, build_transform
+
+# load some utils
+from utils.misc import load_weight, compute_flops
+from utils.box_ops import rescale_bboxes
+from utils.vis_tools import visualize
+
+from config import build_config
+from models import build_model
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
+    # Basic setting
+    parser.add_argument('-size', '--img_size', default=640, type=int,
+                        help='the max size of input image')
+    parser.add_argument('--cuda', action='store_true', default=False, 
+                        help='use cuda.')
+
+    # Model setting
+    parser.add_argument('-m', '--model', default='yolo_n', type=str,
+                        help='build yolo')
+    parser.add_argument('--weight', default=None,
+                        type=str, help='Trained state_dict file path to open')
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
+    parser.add_argument('--fuse_rep_conv', action='store_true', default=False,
+                        help='fuse Conv & BN')
+
+    # Data setting
+    parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
+                        help='data root')
+
+    return parser.parse_args()
+
+
+@torch.no_grad()
+def test_det(model, 
+             device, 
+             dataset,
+             transform=None
+             ):
+    # Step-1: Compute FLOPs and Params
+    compute_flops(model, cfg.test_img_size, device)
+
+    # Step-2: Compute FPS
+    num_images = 2002
+    total_time = 0
+    count = 0
+    with torch.no_grad():
+        for index in range(num_images):
+            if index % 500 == 0:
+                print('Testing image {:d}/{:d}....'.format(index+1, num_images))
+
+            # Load an image
+            image, _ = dataset.pull_image(index)
+
+            # Preprocess
+            x, _, ratio = transform(image)
+            x = x.unsqueeze(0).to(device)
+
+            # Start
+            torch.cuda.synchronize()
+            start_time = time.perf_counter()   
+
+            # Inference
+            outputs = model(x)
+
+            # End
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+        
+            if index > 1:
+                total_time += elapsed
+                count += 1
+
+        print('- FPS :', 1.0 / (total_time / count))
+
+if __name__ == '__main__':
+    args = parse_args()
+    # cuda
+    if args.cuda:
+        print('use cuda')
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    # Model Config
+    cfg = build_config(args)
+
+    # Transform
+    transform = build_transform(cfg, is_train=False)
+
+    # Dataset
+    args.dataset = 'coco'
+    dataset = build_dataset(args, cfg, transform, is_train=False)
+
+    # Build model
+    model = build_model(args, cfg, is_val=False)
+
+    # Load trained weight
+    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_rep_conv)
+    model.to(device).eval()
+        
+    # Run
+    test_det(model     = model, 
+             device    = device, 
+             dataset   = dataset,
+             transform = transform,
+             )