yjh0410 1 년 전
부모
커밋
7c2ae04a4e
2개의 변경된 파일17개의 추가작업 그리고 16개의 파일을 삭제
  1. 15 14
      yolo/test.py
  2. 2 2
      yolo/utils/misc.py

+ 15 - 14
yolo/test.py

@@ -107,35 +107,30 @@ def test_det(args,
 
 if __name__ == '__main__':
     args = parse_args()
-    # cuda
-    if args.cuda:
+    # Set cuda
+    if args.cuda and torch.cuda.is_available():
         print('use cuda')
         device = torch.device("cuda")
     else:
         device = torch.device("cpu")
 
-    # Dataset & Model Config
+    # Build config
     cfg = build_config(args)
 
-    # Transform
+    # Build data processor
     transform = build_transform(cfg, is_train=False)
 
-    # Dataset
+    # Build dataset
     dataset = build_dataset(args, cfg, transform, is_train=False)
 
-    np.random.seed(0)
-    class_colors = [(np.random.randint(255),
-                     np.random.randint(255),
-                     np.random.randint(255)) for _ in range(cfg.num_classes)]
-
-    # build model
+    # Build model
     model = build_model(args, cfg, is_val=False)
 
-    # load trained weight
+    # Load trained weight
     model = load_weight(model, args.weight, args.fuse_conv_bn)
     model.to(device).eval()
 
-    # compute FLOPs and Params
+    # Compute FLOPs and Params
     model_copy = deepcopy(model)
     model_copy.trainable = False
     model_copy.eval()
@@ -143,7 +138,13 @@ if __name__ == '__main__':
     del model_copy
         
     print("================= DETECT =================")
-    # run
+    # Color for beautiful visualization
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255))
+                     for _ in range(cfg.num_classes)]
+    # Run
     test_det(args         = args,
              model        = model, 
              device       = device, 

+ 2 - 2
yolo/utils/misc.py

@@ -365,9 +365,9 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False):
         model = fuse_conv_bn(model)
 
     # Fuse RepConv
-    if hasattr(model, "switch_deploy"):
+    if hasattr(model, "switch_to_deploy"):
         print("Reparam ...")
-        model.switch_deploy()
+        model.switch_to_deploy()
 
     return model