demo.py

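"""Demo script: run a trained detector on a folder of images, a video file,
or a live camera stream, draw the predicted boxes, and save the visualized
results (optionally also as a GIF)."""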
import argparse
import cv2
import os
import time
import numpy as np
import imageio
import torch

# load transform
from dataset.build import build_transform

# load some utils
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from utils.vis_tools import visualize

from models import build_model
from config import build_config
from dataset.voc import voc_class_labels
from dataset.coco import coco_class_labels
from dataset.custom import custom_class_labels


def parse_args():
    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
    # Basic setting
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='the max size of input image')
    parser.add_argument('--mode', default='image', type=str,
                        help='use the data from image, video or camera')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--path_to_img', default='dataset/demo/images/', type=str,
                        help='the path to image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/', type=str,
                        help='the path to video files')
    parser.add_argument('--path_to_save', default='det_results/demos/', type=str,
                        help='the path to save the detection results')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='generate gif.')
    # Model setting
    parser.add_argument('-m', '--model', default='yolo_n', type=str,
                        help='build yolo')
    parser.add_argument('--weight', default=None, type=str,
                        help='trained state_dict file path to open')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    # Data setting
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, custom.')

    return parser.parse_args()
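
# Example invocations (the weight file path is illustrative; point --weight at
# your own trained checkpoint):
#   python demo.py --mode image  --path_to_img dataset/demo/images/ -m yolo_n --weight weights/yolo_n.pth --show
#   python demo.py --mode video  --path_to_vid dataset/demo/videos/demo.mp4 -m yolo_n --weight weights/yolo_n.pth --gif
#   python demo.py --mode camera -m yolo_n --weight weights/yolo_n.pth --cuda --show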

@torch.no_grad()
def detect(args,
           model,
           device,
           transform,
           num_classes,
           class_names,
           mode='image'):
    """Run the detector in 'image', 'video', or 'camera' mode and save the
    visualized results under args.path_to_save/<mode>. The model is expected
    to return a dict with 'scores', 'labels', and 'bboxes'."""
    # one fixed random color per class (seeded for reproducibility)
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(num_classes)]

    save_path = os.path.join(args.path_to_save, mode)
    os.makedirs(save_path, exist_ok=True)
    # ------------------------- Camera ----------------------------
    if mode == 'camera':
        print('use camera !!!')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []
        # built-in laptop camera: index=0; external camera: index=1
        # CAP_DSHOW selects the DirectShow backend (Windows only)
        cap = cv2.VideoCapture(index=0, apiPreference=cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device)

                # inference
                t0 = time.time()
                outputs = model(x)
                scores = outputs['scores']
                labels = outputs['labels']
                bboxes = outputs['bboxes']
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes to the original image size
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # vis detection
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)
                if args.gif:
                    # keep the visualized frame; convert BGR -> RGB for imageio
                    gif_resized_rgb = frame_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)
                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif mode == 'video':
        # note: args.path_to_vid should point at a video file, not a directory
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []
        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device)

                # inference
                t0 = time.time()
                outputs = model(x)
                scores = outputs['scores']
                labels = outputs['labels']
                bboxes = outputs['bboxes']
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes to the original image size
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # vis detection
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)
                if args.gif:
                    # keep the visualized frame; convert BGR -> RGB for imageio
                    gif_resized_rgb = frame_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)
                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            if image is None:
                # skip files that OpenCV cannot decode as images
                continue
            orig_h, orig_w, _ = image.shape
            # prepare
            x, _, ratio = transform(image)
            x = x.unsqueeze(0).to(device)

            # inference
            t0 = time.time()
            outputs = model(x)
            scores = outputs['scores']
            labels = outputs['labels']
            bboxes = outputs['bboxes']
            t1 = time.time()
            print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

            # rescale bboxes to the original image size
            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

            # vis detection
            img_processed = visualize(image=image,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)

            if args.show:
                cv2.imshow('detection', img_processed)
                cv2.waitKey(0)


def run():
    args = parse_args()

    # Build config (must happen before the dataset-specific overrides below,
    # which previously referenced cfg before it was defined)
    cfg = build_config(args)

    # Dataset config
    if args.dataset == "voc":
        cfg.num_classes = 20
        cfg.class_labels = voc_class_labels
    elif args.dataset == "coco":
        cfg.num_classes = 80
        cfg.class_labels = coco_class_labels
    elif args.dataset == "custom":
        cfg.num_classes = len(custom_class_labels)
        cfg.class_labels = custom_class_labels
    else:
        raise NotImplementedError("Unknown dataset: {}".format(args.dataset))

    # cuda
    if args.cuda and torch.cuda.is_available():
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Build model
    model = build_model(args, cfg, False)

    # Load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # Build transform
    transform = build_transform(cfg, is_train=False)

    print("================= DETECT =================")
    # Run demo
    detect(args=args,
           mode=args.mode,
           model=model,
           device=device,
           transform=transform,
           num_classes=cfg.num_classes,
           class_names=cfg.class_labels)


if __name__ == '__main__':
    run()