demo.py

import argparse
import os
import time

import cv2
import imageio
import numpy as np
import torch

# load transform
from dataset.build import build_transform
# load some utils
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from utils.vis_tools import visualize
from models.detectors import build_model
from config import build_model_config, build_trans_config, build_dataset_config
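
# Note: dataset/, utils/, models/, and config above are repo-local modules, so
# this script is typically run from the repository root.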


def parse_args():
    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
    # Basic setting
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation probability')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation probability')
    parser.add_argument('--mode', default='image', type=str,
                        help='read data from an image folder, a video, or the camera')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    parser.add_argument('--path_to_img', default='dataset/demo/images/',
                        type=str, help='path to the image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/',
                        type=str, help='path to the video file')
    parser.add_argument('--path_to_save', default='det_results/demos/',
                        type=str, help='path to save the detection results')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='generate a GIF of the results')

    # Model setting
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='build yolo')
    parser.add_argument('-nc', '--num_classes', default=80, type=int,
                        help='number of classes')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.35, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='top-k candidate detections per level before NMS')
    parser.add_argument('--deploy', action='store_true', default=False,
                        help='deploy mode or not')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN layers')
    parser.add_argument('--no_multi_labels', action='store_true', default=False,
                        help='disable the multi-label trick during post-processing')
    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
                        help='perform NMS regardless of category')

    # Data setting
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, crowdhuman, widerface')

    return parser.parse_args()
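
# Example invocations (the weight filenames and video path below are
# illustrative, not files shipped with the repo):
#   python demo.py --mode image  -m yolov1 --weight weights/yolov1_coco.pth --cuda --show
#   python demo.py --mode video  -m yolov1 --weight weights/yolov1_coco.pth --path_to_vid dataset/demo/videos/demo.mp4
#   python demo.py --mode camera -m yolov1 --weight weights/yolov1_coco.pth --show --gif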


# gradients are not needed at inference time
@torch.no_grad()
def detect(args,
           model,
           device,
           transform,
           num_classes,
           class_names,
           class_indexs,
           mode='image'):
    # one fixed random color per class (seeded for reproducibility)
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(num_classes)]

    # results are saved under <path_to_save>/<mode>/
    save_path = os.path.join(args.path_to_save, mode)
    os.makedirs(save_path, exist_ok=True)
    # ------------------------- Camera ----------------------------
    if mode == 'camera':
        print('use camera !!!')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        # cv2.CAP_DSHOW selects the DirectShow backend, which is Windows-only;
        # use plain cv2.VideoCapture(0) on Linux/macOS
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                # press 'q' to quit
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape

                # prepare: transform returns (tensor, target, ratio); the target is unused here
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device)

                # inference
                t0 = time.time()
                outputs = model(x)
                scores = outputs['scores']
                labels = outputs['labels']
                bboxes = outputs['bboxes']
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes back to the original resolution
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # visualize detections
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names,
                                      class_indexs=class_indexs)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    gif_resized = cv2.resize(frame, save_size)
                    # convert BGR (OpenCV) to RGB (imageio)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            # note: newer imageio releases deprecate fps= for GIFs in favor of duration= (ms per frame)
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif mode == 'video':
        # args.path_to_vid should point to a video file that OpenCV can decode
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device)

                # inference
                t0 = time.time()
                outputs = model(x)
                scores = outputs['scores']
                labels = outputs['labels']
                bboxes = outputs['bboxes']
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes back to the original resolution
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # visualize detections
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names,
                                      class_indexs=class_indexs)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    gif_resized = cv2.resize(frame, save_size)
                    # convert BGR (OpenCV) to RGB (imageio)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            orig_h, orig_w, _ = image.shape

            # prepare
            x, _, ratio = transform(image)
            x = x.unsqueeze(0).to(device)

            # inference
            t0 = time.time()
            outputs = model(x)
            scores = outputs['scores']
            labels = outputs['labels']
            bboxes = outputs['bboxes']
            t1 = time.time()
            print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

            # rescale bboxes back to the original resolution
            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

            # visualize detections
            img_processed = visualize(image=image,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names,
                                      class_indexs=class_indexs)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)

            if args.show:
                cv2.imshow('detection', img_processed)
                cv2.waitKey(0)
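
# The camera, video, and image branches above share the same per-frame pipeline
# (transform -> model forward -> rescale_bboxes -> visualize); only the frame
# source and the output sink differ.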


def run():
    args = parse_args()

    # device: fall back to CPU when CUDA is requested but unavailable
    if args.cuda and torch.cuda.is_available():
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])
    data_cfg = build_dataset_config(args)

    ## Data info
    num_classes = data_cfg['num_classes']
    class_names = data_cfg['class_names']
    class_indexs = data_cfg['class_indexs']

    # build model
    model = build_model(args, model_cfg, device, num_classes, False)

    # load trained weights
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # transform
    val_transform, trans_cfg = build_transform(args, trans_cfg, model_cfg['max_stride'], is_train=False)

    print("================= DETECT =================")
    # run
    detect(args=args,
           mode=args.mode,
           model=model,
           device=device,
           transform=val_transform,
           num_classes=num_classes,
           class_names=class_names,
           class_indexs=class_indexs)


if __name__ == '__main__':
    run()