瀏覽代碼

fix a bug in export_onnx

yjh0410 2 年之前
父節點
當前提交
a344604de0
共有 2 個文件被更改,包括 12 次插入0 次删除
  1. 9 0
      engine.py
  2. 3 0
      tools/export_onnx.py

+ 9 - 0
engine.py

@@ -1298,6 +1298,7 @@ class RTRTrainer(object):
                 
             # Visualize train targets
             if self.args.vis_tgt:
+                targets = self.denormalize_bbox(targets, img_size)
                 vis_data(images*255, targets)
 
             # Inference
@@ -1384,6 +1385,14 @@ class RTRTrainer(object):
         return targets
 
 
+    def denormalize_bbox(self, targets, img_size):
+        # normalize targets
+        for tgt in targets:
+            tgt["boxes"] *= img_size
+        
+        return targets
+
+
     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
         """
             Deployed for Multi scale trick.

+ 3 - 0
tools/export_onnx.py

@@ -59,6 +59,9 @@ def make_parser():
                         help='topk candidates for testing')
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
+    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
+                        help='Perform NMS operations regardless of category.')
+    
 
     return parser