demo.py

import argparse
import cv2
import os
import time
import numpy as np
import imageio
import torch

# load transform
from dataset.build import build_transform

# load some utils
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from utils.vis_tools import visualize

from models.detectors import build_model
from config import build_model_config, build_trans_config, build_dataset_config
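
# Note: dataset.build, utils.misc, utils.box_ops, utils.vis_tools,
# models.detectors, and config are modules of the surrounding project, so this
# script is meant to be run from the repository root rather than standalone.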


def parse_args():
    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
    # Basic setting
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='the max size of input image')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--mode', default='image', type=str,
                        help='use the data from image, video or camera')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--path_to_img', default='dataset/demo/images/', type=str,
                        help='the path to image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/', type=str,
                        help='the path to video files')
    parser.add_argument('--path_to_save', default='det_results/demos/', type=str,
                        help='the path to save the detection results')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='generate gif.')
    # Model setting
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='build yolo')
    parser.add_argument('-nc', '--num_classes', default=80, type=int,
                        help='number of classes.')
    parser.add_argument('--weight', default=None, type=str,
                        help='trained state_dict file path to open')
    parser.add_argument('-ct', '--conf_thresh', default=0.35, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='topk candidate dets of each level before NMS')
    parser.add_argument('--deploy', action='store_true', default=False,
                        help='deploy mode or not')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    parser.add_argument('--no_multi_labels', action='store_true', default=False,
                        help='disable the multi-label trick during post-processing.')
    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
                        help='perform NMS operations regardless of category.')
    # Data setting
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, crowdhuman, widerface.')

    return parser.parse_args()
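

# Example invocations (the weight and video paths are placeholders; point them
# at your own trained checkpoint and input file):
#   python demo.py --mode image  -m yolov1 --weight path/to/yolov1.pth --show
#   python demo.py --mode video  -m yolov1 --weight path/to/yolov1.pth --path_to_vid dataset/demo/videos/demo.mp4
#   python demo.py --mode camera -m yolov1 --weight path/to/yolov1.pth --cuda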


def detect(args,
           model,
           device,
           transform,
           num_classes,
           class_names,
           class_indexs,
           mode='image'):
    # class colors: a fixed seed keeps the per-class colors identical across runs
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(num_classes)]

    save_path = os.path.join(args.path_to_save, mode)
    os.makedirs(save_path, exist_ok=True)

    # ------------------------- Camera ----------------------------
    if mode == 'camera':
        print('use camera !!!')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        # cv2.CAP_DSHOW selects the Windows DirectShow backend; drop the second
        # argument on Linux/macOS
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                # press 'q' to quit
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.

                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes to the original frame size
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # vis detection
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names,
                                      class_indexs=class_indexs)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # collect the visualized frame (not the raw one) so the GIF
                    # shows the detections; flip BGR -> RGB for imageio
                    gif_resized = cv2.resize(frame_vis, save_size)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break

        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))

    # ------------------------- Video ---------------------------
    elif mode == 'video':
        # note: the default --path_to_vid points at a directory; pass the path
        # to a concrete video file here
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape

                # prepare
                x, _, ratio = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.

                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                t1 = time.time()
                print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

                # rescale bboxes to the original frame size
                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

                # vis detection
                frame_vis = visualize(image=frame,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names,
                                      class_indexs=class_indexs)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)

                if args.gif:
                    # collect the visualized frame (not the raw one) so the GIF
                    # shows the detections; flip BGR -> RGB for imageio
                    gif_resized = cv2.resize(frame_vis, save_size)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break

        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))

    # ------------------------- Image ----------------------------
    elif mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            orig_h, orig_w, _ = image.shape

            # prepare
            x, _, ratio = transform(image)
            x = x.unsqueeze(0).to(device) / 255.

            # inference
            t0 = time.time()
            bboxes, scores, labels = model(x)
            t1 = time.time()
            print("Infer time: {:.1f} ms.".format((t1 - t0) * 1000))

            # rescale bboxes to the original image size
            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

            # vis detection
            img_processed = visualize(image=image,
                                      bboxes=bboxes,
                                      scores=scores,
                                      labels=labels,
                                      class_colors=class_colors,
                                      class_names=class_names,
                                      class_indexs=class_indexs)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)

            if args.show:
                cv2.imshow('detection', img_processed)
                # wait for a key press before moving on to the next image
                cv2.waitKey(0)


def run():
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])
    data_cfg = build_dataset_config(args)

    # Data info
    num_classes = data_cfg['num_classes']
    class_names = data_cfg['class_names']
    class_indexs = data_cfg['class_indexs']

    # build model with the dataset's class count so it matches class_names
    model = build_model(args, model_cfg, device, num_classes, False)

    # load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # transform
    val_transform, trans_cfg = build_transform(args, trans_cfg, model_cfg['max_stride'], is_train=False)

    print("================= DETECT =================")
    # run
    detect(args=args,
           mode=args.mode,
           model=model,
           device=device,
           transform=val_transform,
           num_classes=num_classes,
           class_names=class_names,
           class_indexs=class_indexs,
           )


if __name__ == '__main__':
    run()