ソースを参照

add ONNX deployment

yjh0410 2 年 前
コミット
3f8c6ecea7

+ 5 - 1
README.md

@@ -139,7 +139,7 @@ I have provided a bash file `train_ddp.sh` that enables DDP training. I hope som
 | Model         |   Backbone         | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |---------------|--------------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
 | YOLOX-N       | CSPDarkNet-N       |  640  |  300  |         31.1           |       49.5        |   7.5             |   2.3              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_n_coco.pth) |
-| YOLOX-S       | CSPDarkNet-S       |  640  |  300  |                        |                   |   27.1            |   9.0              |  |
+| YOLOX-S       | CSPDarkNet-S       |  640  |  300  |                        |                   |   26.8            |   8.9              |  |
 | YOLOX-M       | CSPDarkNet-M       |  640  |  300  |                        |                   |   74.3            |   25.4             |  |
 | YOLOX-L       | CSPDarkNet-L       |  640  |  300  |                        |                   |   155.4           |   54.2             |  |
 
@@ -360,3 +360,7 @@ python track.py --mode video \
 Results:
 
 ![image](./img_files/video_tracking_demo.gif)
+
+
+## Deployment
+1. [ONNX export and ONNXRuntime inference demo](./deployment/ONNXRuntime/)

+ 44 - 0
deployment/ONNXRuntime/README.md

@@ -0,0 +1,44 @@
+## YOLO ONNXRuntime
+
+
+### Convert Your Model to ONNX
+
+First, you should move to <FreeYOLO_HOME> by:
+```shell
+cd <FreeYOLO_HOME>
+cd tools/
+```
+Then, you can:
+
+1. Convert a standard YOLO model by:
+```shell
+python3 export_onnx.py -m yolov1 --weight ../weight/coco/yolov1/yolov1_coco.pth -nc 80 --img_size 640
+```
+
+Notes:
+* -m (--model): specify a model name, e.g. one of [yolov1, yolov2, yolov3, yolov4, yolov5_s/m/l/x, yolov7_t/l/x, yolox_n/s/m/l/x]
+* --weight: path to the checkpoint of the model you have trained
+* --opset: opset version, default 11. **However, if you will further convert your onnx model to [OpenVINO](https://github.com/Megvii-BaseDetection/YOLOX/demo/OpenVINO/), please specify the opset version to 10.**
+* --no-onnxsim: disable onnxsim
+* To customize an input shape for onnx model,  modify the following code in tools/export_onnx.py:
+
+    ```python
+    dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size)
+    ```
+
+### ONNXRuntime Demo
+
+Step1.
+```shell
+cd <FreeYOLO_HOME>/deployment/ONNXRuntime
+```
+
+Step2. 
+```shell
+python3 onnx_inference.py -m ../../weights/onnx/11/yolov1.onnx -i ../test_image.jpg -s 0.3 --img_size 640
+```
+Notes:
+* -m (--model): path to your converted onnx model
+* -i: input_image
+* -s: score threshold for visualization.
+* --img_size: should be consistent with the shape you used for onnx conversion.

+ 87 - 0
deployment/ONNXRuntime/onnx_inference.py

@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+
+import cv2
+import time
+import numpy as np
+import sys
+sys.path.append('../../')
+
+import onnxruntime
+from utils.misc import PreProcessor, PostProcessor
+from utils.vis_tools import visualize
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("onnxruntime inference sample")
+    parser.add_argument("-m", "--model", type=str, default="../../weights/onnx/11/yolov1.onnx",
+                        help="Input your onnx model.")
+    parser.add_argument("-i", "--image_path", type=str, default='../test_image.jpg',
+                        help="Path to your input image.")
+    parser.add_argument("-o", "--output_dir", type=str, default='../../det_results/onnx/',
+                        help="Path to your output directory.")
+    parser.add_argument("-s", "--score_thr", type=float, default=0.35,
+                        help="Score threshould to filter the result.")
+    parser.add_argument("-size", "--img_size", type=int, default=640,
+                        help="Specify an input shape for inference.")
+    return parser
+
+
+if __name__ == '__main__':
+    args = make_parser().parse_args()
+
+    # class color for better visualization
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(80)]
+
+    # preprocessor
+    prepocess = PreProcessor(img_size=args.img_size)
+
+    # postprocessor
+    postprocess = PostProcessor(num_classes=80, conf_thresh=args.score_thr, nms_thresh=0.5)
+
+    # read an image
+    input_shape = tuple([args.img_size, args.img_size])
+    origin_img = cv2.imread(args.image_path)
+
+    # preprocess
+    x, ratio = prepocess(origin_img)
+
+    t0 = time.time()
+    # inference
+    session = onnxruntime.InferenceSession(args.model)
+
+    ort_inputs = {session.get_inputs()[0].name: x[None, :, :, :]}
+    output = session.run(None, ort_inputs)
+    print("inference time: {:.1f} ms".format((time.time() - t0)*1000))
+
+    t0 = time.time()
+    # post process
+    bboxes, scores, labels = postprocess(output[0])
+    bboxes /= ratio
+    print("post-process time: {:.1f} ms".format((time.time() - t0)*1000))
+
+    # visualize detection
+    origin_img = visualize(
+        img=origin_img,
+        bboxes=bboxes,
+        scores=scores,
+        labels=labels,
+        vis_thresh=args.score_thr,
+        class_colors=class_colors
+        )
+
+    # show
+    cv2.imshow('onnx detection', origin_img)
+    cv2.waitKey(0)
+
+    # save results
+    os.makedirs(args.output_dir, exist_ok=True)
+    output_path = os.path.join(args.output_dir, os.path.basename(args.image_path))
+    cv2.imwrite(output_path, origin_img)

BIN
deployment/test_image.jpg


+ 9 - 8
models/detectors/__init__.py

@@ -16,35 +16,36 @@ def build_model(args,
                 model_cfg,
                 device, 
                 num_classes=80, 
-                trainable=False):
+                trainable=False,
+                deploy=False):
     # YOLOv1    
     if args.model == 'yolov1':
         model, criterion = build_yolov1(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
     # YOLOv2   
     elif args.model == 'yolov2':
         model, criterion = build_yolov2(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
     # YOLOv3   
     elif args.model in ['yolov3', 'yolov3_t']:
         model, criterion = build_yolov3(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
     # YOLOv4   
     elif args.model in ['yolov4', 'yolov4_t']:
         model, criterion = build_yolov4(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
     # YOLOv5   
     elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
         model, criterion = build_yolov5(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
     # YOLOv7
     elif args.model in ['yolov7_t', 'yolov7_l', 'yolov7_x']:
         model, criterion = build_yolov7(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
     # YOLOX   
     elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
         model, criterion = build_yolox(
-            args, model_cfg, device, num_classes, trainable)
+            args, model_cfg, device, num_classes, trainable, deploy)
 
     if trainable:
         # Load pretrained weight

+ 3 - 2
models/detectors/yolov1/build.py

@@ -9,7 +9,7 @@ from .yolov1 import YOLOv1
 
 
 # build object detector
-def build_yolov1(args, cfg, device, num_classes=80, trainable=False):
+def build_yolov1(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -24,7 +24,8 @@ def build_yolov1(args, cfg, device, num_classes=80, trainable=False):
         num_classes = num_classes,
         conf_thresh = args.conf_thresh,
         nms_thresh = args.nms_thresh,
-        trainable = trainable
+        trainable = trainable,
+        deploy = deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 15 - 7
models/detectors/yolov1/yolov1.py

@@ -18,7 +18,8 @@ class YOLOv1(nn.Module):
                  num_classes=20,
                  conf_thresh=0.01,
                  nms_thresh=0.5,
-                 trainable=False):
+                 trainable=False,
+                 deploy=False):
         super(YOLOv1, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # 模型配置文件
@@ -29,6 +30,7 @@ class YOLOv1(nn.Module):
         self.conf_thresh = conf_thresh                 # 得分阈值
         self.nms_thresh = nms_thresh                   # NMS阈值
         self.stride = 32                               # 网络的最大步长
+        self.deploy = deploy
         
         # ------------------- Network Structure -------------------
         ## 主干网络
@@ -148,12 +150,18 @@ class YOLOv1(nn.Module):
         # 解算边界框, 并归一化边界框: [H*W, 4]
         bboxes = self.decode_boxes(reg_pred, fmp_size)
         
-        # 将预测放在cpu处理上,以便进行后处理
-        scores = scores.cpu().numpy()
-        bboxes = bboxes.cpu().numpy()
-        
-        # 后处理
-        bboxes, scores, labels = self.postprocess(bboxes, scores)
+        if self.deploy:
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # 将预测放在cpu处理上,以便进行后处理
+            scores = scores.cpu().numpy()
+            bboxes = bboxes.cpu().numpy()
+            
+            # 后处理
+            bboxes, scores, labels = self.postprocess(bboxes, scores)
 
         return bboxes, scores, labels
 

+ 2 - 1
models/detectors/yolov2/build.py

@@ -9,7 +9,7 @@ from .yolov2 import YOLOv2
 
 
 # build object detector
-def build_yolov2(args, cfg, device, num_classes=80, trainable=False):
+def build_yolov2(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -25,6 +25,7 @@ def build_yolov2(args, cfg, device, num_classes=80, trainable=False):
         conf_thresh=args.conf_thresh,
         nms_thresh=args.nms_thresh,
         topk=args.topk,
+        deploy=deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 16 - 6
models/detectors/yolov2/yolov2.py

@@ -18,7 +18,8 @@ class YOLOv2(nn.Module):
                  conf_thresh=0.01,
                  nms_thresh=0.5,
                  topk=100,
-                 trainable=False):
+                 trainable=False,
+                 deploy=False):
         super(YOLOv2, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # 模型配置文件
@@ -29,8 +30,9 @@ class YOLOv2(nn.Module):
         self.nms_thresh = nms_thresh                   # NMS阈值
         self.topk = topk                               # topk
         self.stride = 32                               # 网络的最大步长
+        self.deploy = deploy
         # ------------------- Anchor box -------------------
-        self.anchor_size = torch.as_tensor(cfg['anchor_size']).view(-1, 2) # [A, 2]
+        self.anchor_size = torch.as_tensor(cfg['anchor_size']).float().view(-1, 2) # [A, 2]
         self.num_anchors = self.anchor_size.shape[0]
         
         # ------------------- Network Structure -------------------
@@ -179,11 +181,19 @@ class YOLOv2(nn.Module):
         cls_pred = cls_pred[0]       # [H*W*A, NC]
         reg_pred = reg_pred[0]       # [H*W*A, 4]
 
-        # post process
-        bboxes, scores, labels = self.postprocess(
-            obj_pred, cls_pred, reg_pred, anchors)
+        if self.deploy:
+            scores = torch.sqrt(obj_pred.sigmoid() * cls_pred.sigmoid())
+            bboxes = self.decode_boxes(anchors, reg_pred)
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
 
-        return bboxes, scores, labels
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.postprocess(
+                obj_pred, cls_pred, reg_pred, anchors)
+
+            return bboxes, scores, labels
 
 
     def forward(self, x):

+ 2 - 1
models/detectors/yolov3/build.py

@@ -9,7 +9,7 @@ from .yolov3 import YOLOv3
 
 
 # build object detector
-def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
+def build_yolov3(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -25,6 +25,7 @@ def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
         conf_thresh=args.conf_thresh,
         nms_thresh=args.nms_thresh,
         topk=args.topk,
+        deploy = deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 19 - 6
models/detectors/yolov3/yolov3.py

@@ -18,7 +18,8 @@ class YOLOv3(nn.Module):
                  conf_thresh=0.01,
                  topk=100,
                  nms_thresh=0.5,
-                 trainable=False):
+                 trainable=False,
+                 deploy=False):
         super(YOLOv3, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # 模型配置文件
@@ -29,12 +30,13 @@ class YOLOv3(nn.Module):
         self.nms_thresh = nms_thresh                   # NMS阈值
         self.topk = topk                               # topk
         self.stride = [8, 16, 32]                      # 网络的输出步长
+        self.deploy = deploy
         # ------------------- Anchor box -------------------
         self.num_levels = 3
         self.num_anchors = len(cfg['anchor_size']) // self.num_levels
         self.anchor_size = torch.as_tensor(
             cfg['anchor_size']
-            ).view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+            ).float().view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
         
         # ------------------- Network Structure -------------------
         ## 主干网络
@@ -196,11 +198,22 @@ class YOLOv3(nn.Module):
             all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
-        # post process
-        bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_box_preds)
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
         
-        return bboxes, scores, labels
+            return bboxes, scores, labels
 
 
     # ---------------------- Main Process for Training ----------------------

+ 2 - 1
models/detectors/yolov4/build.py

@@ -9,7 +9,7 @@ from .yolov4 import YOLOv4
 
 
 # build object detector
-def build_yolov4(args, cfg, device, num_classes=80, trainable=False):
+def build_yolov4(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -25,6 +25,7 @@ def build_yolov4(args, cfg, device, num_classes=80, trainable=False):
         conf_thresh=args.conf_thresh,
         nms_thresh=args.nms_thresh,
         topk=args.topk,
+        deploy = deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 19 - 6
models/detectors/yolov4/yolov4.py

@@ -18,7 +18,8 @@ class YOLOv4(nn.Module):
                  conf_thresh=0.01,
                  nms_thresh=0.5,
                  topk=100,
-                 trainable=False):
+                 trainable=False,
+                 deploy=False):
         super(YOLOv4, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # 模型配置文件
@@ -29,12 +30,13 @@ class YOLOv4(nn.Module):
         self.nms_thresh = nms_thresh                   # NMS阈值
         self.topk = topk                               # topk
         self.stride = [8, 16, 32]                      # 网络的输出步长
+        self.deploy = deploy
         # ------------------- Anchor box -------------------
         self.num_levels = 3
         self.num_anchors = len(cfg['anchor_size']) // self.num_levels
         self.anchor_size = torch.as_tensor(
             cfg['anchor_size']
-            ).view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+            ).float().view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
         
         # ------------------- Network Structure -------------------
         ## 主干网络
@@ -196,11 +198,22 @@ class YOLOv4(nn.Module):
             all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
-        # post process
-        bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_box_preds)
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
         
-        return bboxes, scores, labels
+            return bboxes, scores, labels
 
 
     # ---------------------- Main Process for Training ----------------------

+ 3 - 2
models/detectors/yolov5/build.py

@@ -9,7 +9,7 @@ from .yolov5 import YOLOv5
 
 
 # build object detector
-def build_yolov5(args, cfg, device, num_classes=80, trainable=False):
+def build_yolov5(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -21,10 +21,11 @@ def build_yolov5(args, cfg, device, num_classes=80, trainable=False):
         cfg=cfg,
         device=device, 
         num_classes=num_classes,
-        trainable=trainable,
         conf_thresh=args.conf_thresh,
         nms_thresh=args.nms_thresh,
         topk=args.topk,
+        trainable = trainable,
+        deploy = deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 19 - 6
models/detectors/yolov5/yolov5.py

@@ -16,7 +16,8 @@ class YOLOv5(nn.Module):
                  conf_thresh = 0.05,
                  nms_thresh = 0.6,
                  trainable = False, 
-                 topk = 1000):
+                 topk = 1000,
+                 deploy = False):
         super(YOLOv5, self).__init__()
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
@@ -27,13 +28,14 @@ class YOLOv5(nn.Module):
         self.conf_thresh = conf_thresh
         self.nms_thresh = nms_thresh
         self.topk = topk
+        self.deploy = deploy
         
         # ------------------- Anchor box -------------------
         self.num_levels = 3
         self.num_anchors = len(cfg['anchor_size']) // self.num_levels
         self.anchor_size = torch.as_tensor(
             cfg['anchor_size']
-            ).view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+            ).float().view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
         
         # ------------------- Network Structure -------------------
         ## Backbone
@@ -184,11 +186,22 @@ class YOLOv5(nn.Module):
             all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
-        # post process
-        bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_box_preds)
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
         
-        return bboxes, scores, labels
+            return bboxes, scores, labels
 
 
     # ---------------------- Main Process for Training ----------------------

+ 3 - 2
models/detectors/yolov7/build.py

@@ -9,7 +9,7 @@ from .yolov7 import YOLOv7
 
 
 # build object detector
-def build_yolov7(args, cfg, device, num_classes=80, trainable=False):
+def build_yolov7(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -24,7 +24,8 @@ def build_yolov7(args, cfg, device, num_classes=80, trainable=False):
         conf_thresh = args.conf_thresh,
         nms_thresh = args.nms_thresh,
         topk = args.topk,
-        trainable = trainable
+        trainable = trainable,
+        deploy = deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 18 - 5
models/detectors/yolov7/yolov7.py

@@ -18,7 +18,8 @@ class YOLOv7(nn.Module):
                  conf_thresh=0.01,
                  topk=100,
                  nms_thresh=0.5,
-                 trainable=False):
+                 trainable=False,
+                 deploy = False):
         super(YOLOv7, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # 模型配置文件
@@ -29,6 +30,7 @@ class YOLOv7(nn.Module):
         self.nms_thresh = nms_thresh                   # NMS阈值
         self.topk = topk                               # topk
         self.stride = [8, 16, 32]                      # 网络的输出步长        
+        self.deploy = deploy
         # ------------------- Network Structure -------------------
         ## 主干网络
         self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
@@ -180,11 +182,22 @@ class YOLOv7(nn.Module):
             all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
-        # post process
-        bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_box_preds)
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
         
-        return bboxes, scores, labels
+            return bboxes, scores, labels
 
 
     # ---------------------- Main Process for Training ----------------------

+ 2 - 1
models/detectors/yolox/build.py

@@ -9,7 +9,7 @@ from .yolox import YOLOX
 
 
 # build object detector
-def build_yolox(args, cfg, device, num_classes=80, trainable=False):
+def build_yolox(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
     
@@ -25,6 +25,7 @@ def build_yolox(args, cfg, device, num_classes=80, trainable=False):
         conf_thresh=args.conf_thresh,
         nms_thresh=args.nms_thresh,
         topk=args.topk,
+        deploy=deploy
         )
 
     # -------------- Initialize YOLO --------------

+ 18 - 5
models/detectors/yolox/yolox.py

@@ -17,7 +17,8 @@ class YOLOX(nn.Module):
                  conf_thresh=0.01,
                  nms_thresh=0.5,
                  topk=100,
-                 trainable=False):
+                 trainable=False,
+                 deploy = False):
         super(YOLOX, self).__init__()
         # --------- Basic Parameters ----------
         self.cfg = cfg
@@ -28,6 +29,7 @@ class YOLOX(nn.Module):
         self.conf_thresh = conf_thresh
         self.nms_thresh = nms_thresh
         self.topk = topk
+        self.deploy = deploy
         
         # ------------------- Network Structure -------------------
         ## 主干网络
@@ -172,11 +174,22 @@ class YOLOX(nn.Module):
             all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
-        # post process
-        bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_box_preds)
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
         
-        return bboxes, scores, labels
+            return bboxes, scores, labels
 
 
     def forward(self, x):

+ 16 - 22
tools/export_onnx.py

@@ -14,15 +14,14 @@ from torch import nn
 
 from utils.misc import SiLU
 from utils.misc import load_weight, replace_module
-from config import build_config
+
+from config import  build_model_config
 from models.detectors import build_model
 
 
 def make_parser():
     parser = argparse.ArgumentParser("YOLO ONNXRuntime")
     # basic
-    parser.add_argument("--output-name", type=str, default="yolo_free_large.onnx",
-                        help="output name of models")
     parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='the max size of input image')
     parser.add_argument("--input", default="images", type=str,
@@ -48,7 +47,7 @@ def make_parser():
                         help='Dir to save onnx file')
 
     # model
-    parser.add_argument('-v', '--version', default='yolo_free_large', type=str,
+    parser.add_argument('-m', '--model', default='yolov1', type=str,
                         help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
@@ -72,17 +71,11 @@ def main():
     logger.info("args value: {}".format(args))
     device = torch.device('cpu')
 
-    # config
-    cfg = build_config(args)
+    # Dataset & Model Config
+    model_cfg = build_model_config(args)
 
     # build model
-    model = build_model(
-        args=args, 
-        cfg=cfg,
-        device=device, 
-        num_classes=args.num_classes,
-        trainable=False
-        )
+    model = build_model(args, model_cfg, device, args.num_classes, False, deploy=True)
 
     # replace nn.SiLU with SiLU
     model = replace_module(model, nn.SiLU, SiLU)
@@ -97,20 +90,21 @@ def main():
     # save onnx file
     save_path = os.path.join(args.save_dir, str(args.opset))
     os.makedirs(save_path, exist_ok=True)
-    output_name = os.path.join(save_path, args.output_name)
+    output_name = os.path.join(args.model + '.onnx')
+    output_path = os.path.join(save_path, output_name)
 
     torch.onnx._export(
         model,
         dummy_input,
-        output_name,
+        output_path,
         input_names=[args.input],
-        output_names=[args.output],
+        output_names=[output_name],
         dynamic_axes={args.input: {0: 'batch'},
-                      args.output: {0: 'batch'}} if args.dynamic else None,
+                      output_name: {0: 'batch'}} if args.dynamic else None,
         opset_version=args.opset,
     )
 
-    logger.info("generated onnx model named {}".format(output_name))
+    logger.info("generated onnx model named {}".format(output_path))
 
     if not args.no_onnxsim:
         import onnx
@@ -120,7 +114,7 @@ def main():
         input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
 
         # use onnxsimplify to reduce reduent model.
-        onnx_model = onnx.load(output_name)
+        onnx_model = onnx.load(output_path)
         model_simp, check = simplify(onnx_model,
                                      dynamic_input_shape=args.dynamic,
                                      input_shapes=input_shapes)
@@ -129,9 +123,9 @@ def main():
         # save onnxsim file
         save_path = os.path.join(save_path, 'onnxsim')
         os.makedirs(save_path, exist_ok=True)
-        output_name = os.path.join(save_path, args.output_name)
-        onnx.save(model_simp, output_name)
-        logger.info("generated simplified onnx model named {}".format(output_name))
+        output_path = os.path.join(save_path, output_name)
+        onnx.save(model_simp, output_path)
+        logger.info("generated simplified onnx model named {}".format(output_path))
 
 
 if __name__ == "__main__":

+ 3 - 7
utils/misc.py

@@ -323,16 +323,14 @@ class PreProcessor(object):
         
         # [H, W, C] -> [C, H, W]
         padded_img = padded_img.transpose(swap)
-        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
 
 
         return padded_img, r
 
 ## Post-processer
 class PostProcessor(object):
-    def __init__(self, img_size, strides, num_classes, conf_thresh=0.15, nms_thresh=0.5):
-        self.img_size = img_size
-        self.strides = strides
+    def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
         self.num_classes = num_classes
         self.conf_thresh = conf_thresh
         self.nms_thresh = nms_thresh
@@ -344,9 +342,7 @@ class PostProcessor(object):
             predictions: (ndarray) [n_anchors_all, 4+1+C]
         """
         bboxes = predictions[..., :4]
-        obj_preds = predictions[..., 4:5]
-        cls_preds = predictions[..., 5:]
-        scores = np.sqrt(obj_preds * cls_preds)
+        scores = predictions[..., 4:]
 
         # scores & labels
         labels = np.argmax(scores, axis=1)                      # [M,]