demo.py

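"""Object detection demo.

Runs a trained detector on a folder of images, a video file, or a live
camera stream, draws the detections, and saves the results (optionally
also as a GIF).
"""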
import cv2
import os
import time
import numpy as np
import imageio
import argparse
from PIL import Image
import torch

# load transform
from datasets import coco_labels, build_transform
# load some utils
from utils.misc import load_weight
from utils.vis_tools import visualize
from config import build_config
from models.detectors import build_model

def parse_args():
    parser = argparse.ArgumentParser(description='General Object Detection Demo')
    # Basic
    parser.add_argument('--mode', default='image', type=str,
                        help='use data from an image folder, a video file, or the camera')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    parser.add_argument('--path_to_img', default='./dataset/demo/images/',
                        type=str, help='path to the image files')
    parser.add_argument('--path_to_vid', default='dataset/demo/videos/',
                        type=str, help='path to the video file')
    parser.add_argument('--path_to_save', default='det_results/demos/',
                        type=str, help='path to save the detection results')
    parser.add_argument('-vt', '--visual_threshold', default=0.3, type=float,
                        help='final confidence threshold for visualization')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='generate a GIF of the results')
    # Model
    parser.add_argument('-m', '--model', default='fcos_r18_1x', type=str,
                        help='detector to build')
    parser.add_argument('-nc', '--num_classes', default=80, type=int,
                        help='number of classes')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='top-k candidates for testing')
    parser.add_argument('--deploy', action='store_true', default=False,
                        help='deploy mode or not')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    return parser.parse_args()
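
# Example invocations (sketches only; the checkpoint paths are placeholders
# and must point at real weight files for the chosen model):
#   python demo.py --mode image -m fcos_r18_1x --weight weights/fcos_r18_1x.pth --show
#   python demo.py --mode video --path_to_vid dataset/demo/videos/demo.mp4 \
#       -m fcos_r18_1x --weight weights/fcos_r18_1x.pth --gif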

@torch.no_grad()  # inference only; no gradients needed
def detect(args, model, device, transform, class_names, class_colors):
    # path to save
    save_path = os.path.join(args.path_to_save, args.mode)
    os.makedirs(save_path, exist_ok=True)
    # ------------------------- Camera ----------------------------
    if args.mode == 'camera':
        print('use camera !!!')
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []
        # note: the cv2.CAP_DSHOW (DirectShow) backend is Windows-specific
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break
                orig_h, orig_w, _ = frame.shape
                # to PIL
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # prepare
                x = transform(image)[0]
                x = x.unsqueeze(0).to(device)
                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                print("Infer. time: {:.3f} s".format(time.time() - t0))
                # rescale normalized bboxes to the original image size
                bboxes[..., 0::2] *= orig_w
                bboxes[..., 1::2] *= orig_h
                # visualize detections
                frame_vis = visualize(frame, bboxes, scores, labels, args.visual_threshold, class_colors, class_names)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)
                if args.gif:
                    gif_resized = cv2.resize(frame, save_size)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]  # BGR -> RGB
                    image_list.append(gif_resized_rgb)
                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        cap.release()
        out.release()
        cv2.destroyAllWindows()
        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif args.mode == 'video':
        video = cv2.VideoCapture(args.path_to_vid)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []
        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                orig_h, orig_w, _ = frame.shape
                # to PIL
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # prepare
                x = transform(image)[0]
                x = x.unsqueeze(0).to(device)
                # inference
                t0 = time.time()
                bboxes, scores, labels = model(x)
                print("Infer. time: {:.3f} s".format(time.time() - t0))
                # rescale normalized bboxes to the original image size
                bboxes[..., 0::2] *= orig_w
                bboxes[..., 1::2] *= orig_h
                # visualize detections
                frame_vis = visualize(frame, bboxes, scores, labels, args.visual_threshold, class_colors, class_names)
                frame_resized = cv2.resize(frame_vis, save_size)
                out.write(frame_resized)
                if args.gif:
                    gif_resized = cv2.resize(frame, save_size)
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]  # BGR -> RGB
                    image_list.append(gif_resized_rgb)
                if args.show:
                    cv2.imshow('detection', frame_resized)
                    cv2.waitKey(1)
            else:
                break
        video.release()
        out.release()
        cv2.destroyAllWindows()
        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif args.mode == 'image':
        for i, img_id in enumerate(os.listdir(args.path_to_img)):
            cv2_image = cv2.imread(os.path.join(args.path_to_img, img_id), cv2.IMREAD_COLOR)
            if cv2_image is None:
                continue  # skip files that cannot be read as images
            orig_h, orig_w, _ = cv2_image.shape
            # to PIL
            image = Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))
            # prepare
            x = transform(image)[0]
            x = x.unsqueeze(0).to(device)
            # inference
            t0 = time.time()
            bboxes, scores, labels = model(x)
            print("Infer. time: {:.3f} s".format(time.time() - t0))
            # rescale normalized bboxes to the original image size
            bboxes[..., 0::2] *= orig_w
            bboxes[..., 1::2] *= orig_h
            # visualize detections
            img_processed = visualize(cv2_image, bboxes, scores, labels, args.visual_threshold, class_colors, class_names)
            cv2.imwrite(os.path.join(save_path, str(i).zfill(6) + '.jpg'), img_processed)
            if args.show:
                cv2.imshow('detection', img_processed)
                cv2.waitKey(0)

def run():
    args = parse_args()
    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Dataset & Model Config
    cfg = build_config(args)
    # Transform
    transform = build_transform(cfg, is_train=False)
    # fixed seed so class colors are reproducible across runs
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255))
                    for _ in range(args.num_classes)]
    # Model
    model = build_model(args, cfg, device, args.num_classes, False)
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()
    print("================= DETECT =================")
    # run
    detect(args, model, device, transform, coco_labels, class_colors)


if __name__ == '__main__':
    run()