demo.py

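"""
YOLO demo script: runs a trained detector on a folder of images, a video
file, or a live camera stream, draws the detections, and saves the results
under args.path_to_save (default 'det_results/demos/<mode>/'):

  * 'image'  mode -> numbered JPEGs (000000.jpg, ...)
  * 'video'  mode -> a timestamped .avi, plus an optional GIF
  * 'camera' mode -> a timestamped .avi, plus an optional GIF

See the example invocations after parse_args() below.
"""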
import argparse
import os
import time

import cv2
import imageio
import numpy as np
import torch

# load transform
from dataset.build import build_transform
# load some utils
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from utils.vis_tools import visualize
from models.detectors import build_model
from config import build_model_config, build_trans_config


def parse_args():
    parser = argparse.ArgumentParser(description='YOLO Demo')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation probability')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation probability')
    parser.add_argument('--mode', default='image', type=str,
                        help="run on data from 'image', 'video' or 'camera'")
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    parser.add_argument('--path_to_img', default='dataset/demo/images/', type=str,
                        help='path to the image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/', type=str,
                        help='path to the video file')
    parser.add_argument('--path_to_save', default='det_results/demos/', type=str,
                        help='path to save the detection results')
    parser.add_argument('-vt', '--vis_thresh', default=0.4, type=float,
                        help='final confidence threshold for visualization')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='also save the results as a GIF')
    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='name of the YOLO model to build')
    parser.add_argument('-nc', '--num_classes', default=80, type=int,
                        help='number of classes')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='top-k candidates kept for testing')
    parser.add_argument("--deploy", action="store_true", default=False,
                        help="deploy mode or not")
    parser.add_argument('--fuse_repconv', action='store_true', default=False,
                        help='fuse RepConv')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')

    return parser.parse_args()
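
# Hypothetical example invocations (the weight paths are illustrative and not
# shipped with this script; all flags used here are defined in parse_args above):
#
#   python demo.py --mode image  -m yolov1 --weight weights/yolov1.pth --show
#   python demo.py --mode video  -m yolov1 --weight weights/yolov1.pth \
#       --path_to_vid dataset/demo/videos/demo.mp4 --gif
#   python demo.py --mode camera -m yolov1 --weight weights/yolov1.pth --cuda
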
@torch.no_grad()  # inference only; no autograd state needed
def detect(args,
           model,
           device,
           transform,
           vis_thresh,
           mode='image'):
    # fixed seed -> a deterministic, distinct color per class index
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(args.num_classes)]

    save_path = os.path.join(args.path_to_save, mode)
    os.makedirs(save_path, exist_ok=True)
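
    # All three modes below share one per-frame pipeline:
    #   transform -> batch of 1 scaled to [0, 1] -> model(x)
    #   -> rescale_bboxes back to the original resolution -> visualize.
    # (`deltas` presumably carries the resize/padding offsets that
    # rescale_bboxes needs to undo the transform; inferred from the call sites.)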
    # ------------------------- Camera ----------------------------
    if mode == 'camera':
        print('use camera !!!')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        # CAP_DSHOW selects the DirectShow backend (Windows)
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.

                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                t1 = time.time()
                print("detection time used ", t1 - t0, "s")

                # rescale bboxes back to the original image size
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # vis detection
                frame_vis = visualize(img=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      vis_thresh=vis_thresh)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # use the annotated frame, converted BGR -> RGB for imageio
                    gif_resized = cv2.resize(frame_vis, (640, 480))
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break

        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif mode == 'video':
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.

                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                t1 = time.time()
                print("detection time used ", t1 - t0, "s")

                # rescale bboxes back to the original image size
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # vis detection
                frame_vis = visualize(img=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      vis_thresh=vis_thresh)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # use the annotated frame, converted BGR -> RGB for imageio
                    gif_resized = cv2.resize(frame_vis, (640, 480))
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break

        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            if image is None:
                # skip files that OpenCV cannot decode
                continue
            orig_h, orig_w, _ = image.shape

            # prepare
            x, _, deltas = transform(image)
            x = x.unsqueeze(0).to(device) / 255.

            # inference
            t0 = time.time()
            bboxes, scores, labels = model(x)
            t1 = time.time()
            print("detection time used ", t1 - t0, "s")

            # rescale bboxes back to the original image size
            origin_img_size = [orig_h, orig_w]
            cur_img_size = [*x.shape[-2:]]
            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

            # vis detection
            img_processed = visualize(img=image,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      vis_thresh=vis_thresh)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)

            if args.show:
                cv2.imshow('detection', img_processed)
                cv2.waitKey(0)
def run():
    args = parse_args()

    # device: fall back to CPU when CUDA is not requested or not available
    if args.cuda and torch.cuda.is_available():
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # build model
    model = build_model(args, model_cfg, device, args.num_classes, False)

    # load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # transform
    val_transform, trans_cfg = build_transform(
        args=args, trans_config=trans_cfg,
        max_stride=model_cfg['max_stride'], is_train=False)

    print("================= DETECT =================")
    # run
    detect(args=args,
           model=model,
           device=device,
           transform=val_transform,
           mode=args.mode,
           vis_thresh=args.vis_thresh)


if __name__ == '__main__':
    run()