import argparse
import os
import time

import cv2
import imageio
import numpy as np
import torch

# dataset transform
from dataset.build import build_transform
# misc utilities
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from utils.vis_tools import visualize
from models import build_model
from config import build_config
from dataset.coco import coco_class_labels

def parse_args():
    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
    # Basic settings
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--mode', default='image', type=str,
                        help='read data from an image, a video, or the camera')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    parser.add_argument('--path_to_img', default='dataset/demo/images/',
                        type=str, help='path to the image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/',
                        type=str, help='path to the video file')
    parser.add_argument('--path_to_save', default='det_results/demos/',
                        type=str, help='path to save the detection results')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='generate a GIF of the results')
    # Model settings
    parser.add_argument('-m', '--model', default='yolo_n', type=str,
                        help='model name to build')
    parser.add_argument('-nc', '--num_classes', default=80, type=int,
                        help='number of classes')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('--deploy', action='store_true', default=False,
                        help='deploy mode or not')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN layers')
    # Data settings
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, crowdhuman, widerface')

    return parser.parse_args()
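
# Example invocations (the weight and video paths below are hypothetical;
# point them at files you actually have):
#   python demo.py --mode image  -m yolo_n --weight weights/yolo_n.pth --show
#   python demo.py --mode video  -m yolo_n --weight weights/yolo_n.pth --path_to_vid dataset/demo/videos/demo.mp4 --gif
#   python demo.py --mode camera -m yolo_n --weight weights/yolo_n.pth --cuda --show
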
@torch.no_grad()
def detect(args,
           model,
           device,
           transform,
           num_classes,
           class_names,
           mode='image'):
    # deterministic per-class colors (seeded so colors are stable across runs)
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(num_classes)]

    save_path = os.path.join(args.path_to_save, mode)
    os.makedirs(save_path, exist_ok=True)
    # ------------------------- Camera ----------------------------
    if mode == 'camera':
        print('Using the camera ...')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)

        image_list = []
        # NOTE: cv2.CAP_DSHOW is Windows-only; on Linux/macOS use cv2.VideoCapture(0)
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape

                # preprocess
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device)

                # inference
                t0 = time.time()
                outputs = model(x)
                scores = outputs['scores']
                labels = outputs['labels']
                bboxes = outputs['bboxes']
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes back to the original resolution
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # visualize detections
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # collect the visualized frame (BGR -> RGB for imageio)
                    gif_resized = cv2.resize(frame_vis, (640, 480))
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif mode == 'video':
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)

        image_list = []
        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape

                # preprocess
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device)

                # inference
                t0 = time.time()
                outputs = model(x)
                scores = outputs['scores']
                labels = outputs['labels']
                bboxes = outputs['bboxes']
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes back to the original resolution
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # visualize detections
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # collect the visualized frame (BGR -> RGB for imageio)
                    gif_resized = cv2.resize(frame_vis, (640, 480))
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            orig_h, orig_w, _ = image.shape

            # preprocess
            x, _, ratio = transform(image)
            x = x.unsqueeze(0).to(device)

            # inference
            t0 = time.time()
            outputs = model(x)
            scores = outputs['scores']
            labels = outputs['labels']
            bboxes = outputs['bboxes']
            t1 = time.time()
            print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

            # rescale bboxes back to the original resolution
            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

            # visualize detections
            img_processed = visualize(image=image,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)

            if args.show:
                cv2.imshow('detection', img_processed)
                cv2.waitKey(0)
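
# The wall-clock timing above can understate GPU latency: on CUDA, model(x)
# returns as soon as the kernels are queued, not when they finish. A minimal
# sketch of a synchronized timer (the helper name `timed_inference` is ours,
# not part of this repository; it also does not account for the one-off
# warm-up cost of the very first CUDA call):
def timed_inference(model, x, device):
    """Run one forward pass and return (outputs, elapsed milliseconds)."""
    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for pending work before starting the clock
    t0 = time.time()
    outputs = model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # make sure the forward pass has actually finished
    return outputs, (time.time() - t0) * 1000.0
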
def run():
    args = parse_args()

    # device
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    cfg = build_config(args)
    cfg.num_classes = args.num_classes   # honor the CLI flag instead of hard-coding 80
    cfg.class_labels = coco_class_labels

    # build model
    model = build_model(args, cfg, False)

    # load trained weights
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # transform
    transform = build_transform(cfg, is_train=False)

    print("================= DETECT =================")
    # run
    detect(args=args,
           mode=args.mode,
           model=model,
           device=device,
           transform=transform,
           num_classes=cfg.num_classes,
           class_names=cfg.class_labels)


if __name__ == '__main__':
    run()
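
# ----------------------------------------------------------------------
# Illustrative appendix (never called by the demo): the resize/rescale
# contract that `transform` and `rescale_bboxes` are assumed to follow
# above. The transform resizes the frame by some scalar `ratio` to fit
# `img_size`, so predicted xyxy boxes must be divided by that ratio and
# clamped to the original resolution. This is a plain-NumPy assumption,
# not the repository's actual implementation.
def naive_rescale_bboxes(bboxes, orig_size, ratio):
    """Map (N, 4) xyxy boxes from resized-image space back to the original image."""
    orig_w, orig_h = orig_size
    boxes = np.asarray(bboxes, dtype=np.float32) / ratio  # undo the resize
    boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, orig_w)   # clamp x coordinates
    boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, orig_h)   # clamp y coordinates
    return boxes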