test.py

import argparse
import cv2
import os
import time
import numpy as np
from copy import deepcopy

import torch

# load transform
from datasets import build_dataset, build_transform

# load some utils
from utils.misc import load_weight, compute_flops
from utils.vis_tools import visualize

from config import build_config
from models.detectors import build_model
def parse_args():
    parser = argparse.ArgumentParser(description='Object Detection Lab')
    # Basic
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization results.')
    parser.add_argument('--save', action='store_true', default=False,
                        help='save the visualization results.')
    parser.add_argument('--save_folder', default='det_results/', type=str,
                        help='directory to save results.')
    parser.add_argument('-vt', '--visual_threshold', default=0.3, type=float,
                        help='final confidence threshold.')
    parser.add_argument('-ws', '--window_scale', default=1.0, type=float,
                        help='resize scale of the cv2 window for visualization.')
    # Model
    parser.add_argument('-m', '--model', default='yolof_r18_c5_1x', type=str,
                        help='build detector.')
    parser.add_argument('--weight', default=None, type=str,
                        help='path of the trained state_dict file to load.')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN.')
    # Dataset
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root.')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc.')

    return parser.parse_args()
@torch.no_grad()
def test_det(args, model, device, dataset, transform, class_colors, class_names):
    num_images = len(dataset)
    save_path = os.path.join('det_results/', args.dataset, args.model)
    os.makedirs(save_path, exist_ok=True)

    for index, (image, _) in enumerate(dataset):
        print('Testing image {:d}/{:d}....'.format(index + 1, num_images))
        orig_h, orig_w = image.height, image.width

        # Preprocess
        x, _ = transform(image)
        x = x.unsqueeze(0).to(device)

        # Inference
        t0 = time.time()
        outputs = model(x)
        scores = outputs['scores']
        labels = outputs['labels']
        bboxes = outputs['bboxes']
        print("Infer. time: {:.3f} s".format(time.time() - t0))

        # Rescale bboxes to the original image size
        bboxes[..., 0::2] *= orig_w
        bboxes[..., 1::2] *= orig_h

        # Visualize detections
        img_processed = visualize(image=image,
                                  bboxes=bboxes,
                                  scores=scores,
                                  labels=labels,
                                  class_colors=class_colors,
                                  class_names=class_names)
        if args.show:
            h, w = img_processed.shape[:2]
            sw, sh = int(w * args.window_scale), int(h * args.window_scale)
            cv2.namedWindow('detection', 0)
            cv2.resizeWindow('detection', sw, sh)
            cv2.imshow('detection', img_processed)
            cv2.waitKey(0)

        if args.save:
            # Save the visualized result
            cv2.imwrite(os.path.join(save_path, str(index).zfill(6) + '.jpg'), img_processed)
if __name__ == '__main__':
    args = parse_args()

    # CUDA
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Dataset & model config
    cfg = build_config(args)

    # Transform
    transform = build_transform(cfg, is_train=False)

    # Dataset
    dataset = build_dataset(args, cfg, is_train=False)

    # Fixed seed so the per-class colors are reproducible across runs
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(cfg.num_classes)]

    # Model
    model = build_model(args, cfg, is_val=False)
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # Compute FLOPs and Params on a copy so the inference model is untouched
    model_copy = deepcopy(model)
    model_copy.trainable = False
    model_copy.eval()
    compute_flops(
        model=model_copy,
        min_size=cfg['test_min_size'],
        max_size=cfg['test_max_size'],
        device=device)
    del model_copy

    print("================= DETECT =================")
    # Run
    test_det(args=args,
             model=model,
             device=device,
             dataset=dataset,
             transform=transform,
             class_colors=class_colors,
             class_names=cfg.class_labels)
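
A typical invocation of this script, assuming a locally prepared COCO dataset and a trained checkpoint (both paths below are placeholders, not files shipped with the repo):

python test.py --cuda -d coco --root /path/to/COCO -m yolof_r18_c5_1x --weight /path/to/checkpoint.pth -vt 0.3 --show

All flags shown here are defined in parse_args above; drop --cuda to run on CPU, and use --save instead of --show to write the rendered detections to det_results/.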