demo.py
import argparse
import cv2
import os
import time
import numpy as np
import imageio
import torch

# load transform
from dataset.data_augment import build_transform

# load some utils
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from utils.vis_tools import visualize

from models.detectors import build_model
from config import build_model_config, build_trans_config
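
# NOTE: dataset.*, utils.*, models.* and config above are repo-local modules;
# the script assumes it is run from the repository root so they resolve.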


def parse_args():
    parser = argparse.ArgumentParser(description='YOLO Demo')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--mode', default='image', type=str,
                        help='input source: image, video, or camera')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    parser.add_argument('--path_to_img', default='dataset/demo/images/', type=str,
                        help='path to the image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/', type=str,
                        help='path to the video file')
    parser.add_argument('--path_to_save', default='det_results/demos/', type=str,
                        help='path to save the detection results')
    parser.add_argument('-vt', '--vis_thresh', default=0.4, type=float,
                        help='final confidence threshold for visualization')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='also save the results as a GIF')
    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='name of the YOLO model to build')
    parser.add_argument('-nc', '--num_classes', default=80, type=int,
                        help='number of classes')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='top-k candidates kept for testing')
    parser.add_argument('--deploy', action='store_true', default=False,
                        help='deploy mode or not')
    parser.add_argument('--fuse_repconv', action='store_true', default=False,
                        help='fuse RepConv branches')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN layers')

    return parser.parse_args()
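
# Example invocations (all paths below are placeholders -- point --weight at your
# own trained .pth checkpoint, and for video mode pass a video file, not a directory):
#   python demo.py --mode image  -m yolov1 --weight path/to/yolov1.pth --cuda --show
#   python demo.py --mode video  -m yolov1 --weight path/to/yolov1.pth --path_to_vid dataset/demo/videos/your_video.mp4 --cuda
#   python demo.py --mode camera -m yolov1 --weight path/to/yolov1.pth --cuda --gif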

@torch.no_grad()  # inference only; no gradients needed
def detect(args,
           model,
           device,
           transform,
           vis_thresh,
           mode='image'):
    """Run detection on a folder of images, a video file, or a live camera stream."""
    # one fixed random color per class (seeded for reproducibility);
    # sized by num_classes so visualize() never indexes out of range
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255))
                    for _ in range(args.num_classes)]
    save_path = os.path.join(args.path_to_save, mode)
    os.makedirs(save_path, exist_ok=True)
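
    # Each mode below runs the same per-frame pipeline:
    #   1. transform the frame and normalize pixels to [0, 1];
    #   2. forward pass -- the model returns post-processed (bboxes, scores, labels);
    #   3. rescale the boxes back to the original resolution;
    #   4. draw the detections, then save and optionally show the result.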
    # ------------------------- Camera ----------------------------
    if mode == 'camera':
        print('using camera ...')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []
        # CAP_DSHOW is a Windows-only backend; drop it on Linux/macOS
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.

                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                t1 = time.time()
                print("detection time: {:.3f} s".format(t1 - t0))

                # rescale bboxes
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # vis detection
                frame_vis = visualize(img=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      vis_thresh=vis_thresh)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # keep the visualized frame (not the raw one) so the GIF
                    # shows the detections, and flip BGR -> RGB for imageio
                    gif_resized = cv2.resize(frame_vis, save_size)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF (recent imageio versions may prefer duration= over fps=)
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif mode == 'video':
        # args.path_to_vid must point at a single video file, not a directory
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []
        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.

                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                t1 = time.time()
                print("detection time: {:.3f} s".format(t1 - t0))

                # rescale bboxes
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # vis detection
                frame_vis = visualize(img=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      vis_thresh=vis_thresh)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # keep the visualized frame and flip BGR -> RGB for imageio
                    gif_resized = cv2.resize(frame_vis, save_size)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            if image is None:
                # skip files that OpenCV cannot decode
                continue
            orig_h, orig_w, _ = image.shape

            # prepare
            x, _, deltas = transform(image)
            x = x.unsqueeze(0).to(device) / 255.

            # inference
            t0 = time.time()
            bboxes, scores, labels = model(x)
            t1 = time.time()
            print("detection time: {:.3f} s".format(t1 - t0))

            # rescale bboxes
            origin_img_size = [orig_h, orig_w]
            cur_img_size = [*x.shape[-2:]]
            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

            # vis detection
            img_processed = visualize(img=image,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      vis_thresh=vis_thresh)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)

            if args.show:
                cv2.imshow('detection', img_processed)
                cv2.waitKey(0)


def run():
    args = parse_args()
    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # build model
    model = build_model(args, model_cfg, device, args.num_classes, False)

    # load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # transform
    transform = build_transform(args.img_size, trans_cfg, is_train=False)

    print("================= DETECT =================")
    # run
    detect(args=args,
           model=model,
           device=device,
           transform=transform,
           mode=args.mode,
           vis_thresh=args.vis_thresh)


if __name__ == '__main__':
    run()