|
@@ -107,35 +107,30 @@ def test_det(args,
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
|
args = parse_args()
|
|
args = parse_args()
|
|
|
- # cuda
|
|
|
|
|
- if args.cuda:
|
|
|
|
|
|
|
+ # Set cuda
|
|
|
|
|
+ if args.cuda and torch.cuda.is_available():
|
|
|
print('use cuda')
|
|
print('use cuda')
|
|
|
device = torch.device("cuda")
|
|
device = torch.device("cuda")
|
|
|
else:
|
|
else:
|
|
|
device = torch.device("cpu")
|
|
device = torch.device("cpu")
|
|
|
|
|
|
|
|
- # Dataset & Model Config
|
|
|
|
|
|
|
+ # Build config
|
|
|
cfg = build_config(args)
|
|
cfg = build_config(args)
|
|
|
|
|
|
|
|
- # Transform
|
|
|
|
|
|
|
+ # Build data processor
|
|
|
transform = build_transform(cfg, is_train=False)
|
|
transform = build_transform(cfg, is_train=False)
|
|
|
|
|
|
|
|
- # Dataset
|
|
|
|
|
|
|
+ # Build dataset
|
|
|
dataset = build_dataset(args, cfg, transform, is_train=False)
|
|
dataset = build_dataset(args, cfg, transform, is_train=False)
|
|
|
|
|
|
|
|
- np.random.seed(0)
|
|
|
|
|
- class_colors = [(np.random.randint(255),
|
|
|
|
|
- np.random.randint(255),
|
|
|
|
|
- np.random.randint(255)) for _ in range(cfg.num_classes)]
|
|
|
|
|
-
|
|
|
|
|
- # build model
|
|
|
|
|
|
|
+ # Build model
|
|
|
model = build_model(args, cfg, is_val=False)
|
|
model = build_model(args, cfg, is_val=False)
|
|
|
|
|
|
|
|
- # load trained weight
|
|
|
|
|
|
|
+ # Load trained weight
|
|
|
model = load_weight(model, args.weight, args.fuse_conv_bn)
|
|
model = load_weight(model, args.weight, args.fuse_conv_bn)
|
|
|
model.to(device).eval()
|
|
model.to(device).eval()
|
|
|
|
|
|
|
|
- # compute FLOPs and Params
|
|
|
|
|
|
|
+ # Compute FLOPs and Params
|
|
|
model_copy = deepcopy(model)
|
|
model_copy = deepcopy(model)
|
|
|
model_copy.trainable = False
|
|
model_copy.trainable = False
|
|
|
model_copy.eval()
|
|
model_copy.eval()
|
|
@@ -143,7 +138,13 @@ if __name__ == '__main__':
|
|
|
del model_copy
|
|
del model_copy
|
|
|
|
|
|
|
|
print("================= DETECT =================")
|
|
print("================= DETECT =================")
|
|
|
- # run
|
|
|
|
|
|
|
+ # Color for beautiful visualization
|
|
|
|
|
+ np.random.seed(0)
|
|
|
|
|
+ class_colors = [(np.random.randint(255),
|
|
|
|
|
+ np.random.randint(255),
|
|
|
|
|
+ np.random.randint(255))
|
|
|
|
|
+ for _ in range(cfg.num_classes)]
|
|
|
|
|
+ # Run
|
|
|
test_det(args = args,
|
|
test_det(args = args,
|
|
|
model = model,
|
|
model = model,
|
|
|
device = device,
|
|
device = device,
|