|
|
@@ -70,10 +70,6 @@ def parse_args():
|
|
|
parser.add_argument('--load_cache', action='store_true', default=False,
|
|
|
help='load data into memory.')
|
|
|
|
|
|
- # Task setting
|
|
|
- parser.add_argument('-t', '--task', default='det', choices=['det', 'det_seg', 'det_pos', 'det_seg_pos'],
|
|
|
- help='task type.')
|
|
|
-
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
@@ -131,18 +127,6 @@ def test_det(args,
|
|
|
# save result
|
|
|
cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed)
|
|
|
|
|
|
-@torch.no_grad()
|
|
|
-def test_det_seg():
|
|
|
- pass
|
|
|
-
|
|
|
-@torch.no_grad()
|
|
|
-def test_det_pos():
|
|
|
- pass
|
|
|
-
|
|
|
-@torch.no_grad()
|
|
|
-def test_det_seg_pos():
|
|
|
- pass
|
|
|
-
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
args = parse_args()
|
|
|
@@ -199,19 +183,12 @@ if __name__ == '__main__':
|
|
|
|
|
|
print("================= DETECT =================")
|
|
|
# run
|
|
|
- if args.task == "det":
|
|
|
- test_det(args=args,
|
|
|
- model=model,
|
|
|
- device=device,
|
|
|
- dataset=dataset,
|
|
|
- transform=val_transform,
|
|
|
- class_colors=class_colors,
|
|
|
- class_names=dataset_info['class_names'],
|
|
|
- class_indexs=dataset_info['class_indexs'],
|
|
|
- )
|
|
|
- elif args.task == "det_seg":
|
|
|
- test_det_seg()
|
|
|
- elif args.task == "det_pos":
|
|
|
- test_det_pos()
|
|
|
- elif args.task == "det_seg_pos":
|
|
|
- test_det_seg_pos()
|
|
|
+ test_det(args=args,
|
|
|
+ model=model,
|
|
|
+ device=device,
|
|
|
+ dataset=dataset,
|
|
|
+ transform=val_transform,
|
|
|
+ class_colors=class_colors,
|
|
|
+ class_names=dataset_info['class_names'],
|
|
|
+ class_indexs=dataset_info['class_indexs'],
|
|
|
+ )
|