|
|
@@ -1,12 +1,12 @@
|
|
|
import argparse
|
|
|
-import numpy as np
|
|
|
import time
|
|
|
-import os
|
|
|
import torch
|
|
|
|
|
|
+# load transform
|
|
|
from datasets import build_dataset, build_transform
|
|
|
-from utils.misc import compute_flops, fuse_conv_bn
|
|
|
-from utils.misc import load_weight
|
|
|
+
|
|
|
+# load some utils
|
|
|
+from utils.misc import compute_flops, load_weight
|
|
|
|
|
|
from config import build_config
|
|
|
from models.detectors import build_model
|
|
|
@@ -18,8 +18,6 @@ parser.add_argument('-m', '--model', default='fcos_r18_1x',
|
|
|
help='build detector')
|
|
|
parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
|
|
|
help='fuse conv and bn')
|
|
|
-parser.add_argument('--topk', default=100, type=int,
|
|
|
- help='NMS threshold')
|
|
|
parser.add_argument('--weight', default=None, type=str,
|
|
|
help='Trained state_dict file path to open')
|
|
|
# Data root
|
|
|
@@ -36,8 +34,8 @@ def test(cfg, model, device, dataset, transform):
|
|
|
# Step-1: Compute FLOPs and Params
|
|
|
compute_flops(
|
|
|
model=model,
|
|
|
- min_size=cfg['test_min_size'],
|
|
|
- max_size=cfg['test_max_size'],
|
|
|
+ min_size=cfg.test_min_size,
|
|
|
+ max_size=cfg.test_max_size,
|
|
|
device=device)
|
|
|
|
|
|
# Step-2: Compute FPS
|
|
|
@@ -48,29 +46,25 @@ def test(cfg, model, device, dataset, transform):
|
|
|
for index in range(num_images):
|
|
|
if index % 500 == 0:
|
|
|
print('Testing image {:d}/{:d}....'.format(index+1, num_images))
|
|
|
+
|
|
|
+ # Load an image
|
|
|
image, _ = dataset[index]
|
|
|
- orig_h, orig_w = image.height, image.width
|
|
|
|
|
|
- # PreProcess
|
|
|
+ # Preprocess
|
|
|
x, _ = transform(image)
|
|
|
x = x.unsqueeze(0).to(device)
|
|
|
|
|
|
- # star time
|
|
|
+ # Star
|
|
|
torch.cuda.synchronize()
|
|
|
start_time = time.perf_counter()
|
|
|
|
|
|
- # inference
|
|
|
- bboxes, scores, labels = model(x)
|
|
|
+ # Inference
|
|
|
+ outputs = model(x)
|
|
|
|
|
|
- # Rescale bboxes
|
|
|
- bboxes[..., 0::2] *= orig_w
|
|
|
- bboxes[..., 1::2] *= orig_h
|
|
|
-
|
|
|
- # end time
|
|
|
+ # End
|
|
|
torch.cuda.synchronize()
|
|
|
elapsed = time.perf_counter() - start_time
|
|
|
|
|
|
- # print("detection time used ", elapsed, "s")
|
|
|
if index > 1:
|
|
|
total_time += elapsed
|
|
|
count += 1
|
|
|
@@ -95,17 +89,13 @@ if __name__ == '__main__':
|
|
|
|
|
|
# Dataset
|
|
|
args.dataset = 'coco'
|
|
|
- dataset, dataset_info = build_dataset(args, is_train=False)
|
|
|
+ dataset = build_dataset(args, cfg, is_train=False)
|
|
|
|
|
|
# Model
|
|
|
- model = build_model(args, cfg, device, dataset_info['num_classes'], False)
|
|
|
+ model = build_model(args, cfg, is_val=False)
|
|
|
model = load_weight(model, args.weight, args.fuse_conv_bn)
|
|
|
model.to(device).eval()
|
|
|
|
|
|
- # fuse conv bn
|
|
|
- if args.fuse_conv_bn:
|
|
|
- print('fuse conv and bn ...')
|
|
|
- model = fuse_conv_bn(model)
|
|
|
-
|
|
|
- # run
|
|
|
+ print("================= DETECT =================")
|
|
|
+ # Run
|
|
|
test(cfg, model, device, dataset, transform)
|