track.py

import os
import cv2
import time
import imageio
import argparse
import numpy as np
import torch

from dataset.build import build_transform
from utils.vis_tools import plot_tracking
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from config import build_model_config, build_trans_config
from models.detectors import build_model
from models.trackers import build_tracker

# work around "duplicate OpenMP runtime" crashes that can occur when several
# libraries bundle libiomp (common in conda environments)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]

def parse_args():
    parser = argparse.ArgumentParser(description='Tracking Task')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    # data
    parser.add_argument('--mode', type=str, default='image',
                        help='image, video or camera')
    parser.add_argument('--path_to_img', type=str, default='dataset/demo/images/',
                        help='directory to load images from')
    parser.add_argument('--path_to_vid', type=str, default='dataset/demo/videos/',
                        help='path to the video to load')
    parser.add_argument('--path_to_save', default='det_results/', type=str,
                        help='directory to save results to')
    parser.add_argument('--fps', type=int, default=30,
                        help='frame rate')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show results')
    parser.add_argument('--save', action='store_true', default=False,
                        help='save results')
    parser.add_argument('--gif', action='store_true', default=False,
                        help='generate a GIF of the results')
    # tracker
    parser.add_argument('-tk', '--tracker', default='byte_tracker', type=str,
                        help='tracker to build, e.g., byte_tracker')
    parser.add_argument("--track_thresh", type=float, default=0.4,
                        help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30,
                        help="number of frames to keep lost tracks alive")
    parser.add_argument("--match_thresh", type=float, default=0.8,
                        help="matching threshold for tracking")
    parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6,
                        help="filter out boxes whose aspect ratio (w / h) "
                             "is above this value")
    parser.add_argument('--min_box_area', type=float, default=10,
                        help='filter out tiny boxes')
    parser.add_argument("--mot20", default=False, action="store_true",
                        help="test on MOT20")
    # detector
    parser.add_argument('-dt', '--model', default='yolov1', type=str,
                        help='detector to build, e.g., yolov1')
    parser.add_argument('-ns', '--num_classes', type=int, default=80,
                        help='number of object classes')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.3, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='top-k candidates for testing')
    parser.add_argument('-fcb', '--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')

    return parser.parse_args()
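
# Example invocations (the video file and checkpoint path below are
# hypothetical; substitute your own data and weights):
#   python track.py --mode camera --show
#   python track.py --mode video --path_to_vid dataset/demo/videos/demo.mp4 \
#       -dt yolov1 --weight path/to/yolov1.pth --cuda --gif
#   python track.py --mode image --path_to_img dataset/demo/images/ -tk byte_tracker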

def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names
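
# Usage sketch: get_image_list walks the directory recursively and returns
# the paths in no guaranteed order, so callers sort before iterating, e.g.:
#   files = get_image_list('dataset/demo/images/')
#   files.sort()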

def run(args,
        tracker,
        detector,
        device,
        transform):
    save_path = os.path.join(args.path_to_save, 'tracking', args.mode)
    os.makedirs(save_path, exist_ok=True)
    # ------------------------- Camera ----------------------------
    if args.mode == 'camera':
        print('use camera !!!')
        # launch camera
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        frame_id = 0
        results = []

        # for saving
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        fps = 15.0
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        # start tracking
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break

                # ------------------------- Detection ---------------------------
                # preprocess
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.
                orig_h, orig_w, _ = frame.shape

                # detect
                t0 = time.time()
                bboxes, scores, labels = detector(x)
                print("=============== Frame-{} ================".format(frame_id))
                print("detect time: {:.1f} ms".format((time.time() - t0) * 1000))

                # rescale bboxes back to the original image size
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # track
                t2 = time.time()
                if len(bboxes) > 0:
                    online_targets = tracker.update(scores, bboxes, labels)
                    online_xywhs = []
                    online_ids = []
                    online_scores = []
                    for t in online_targets:
                        xywh = t.xywh
                        tid = t.track_id
                        vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
                        if xywh[2] * xywh[3] > args.min_box_area and not vertical:
                            online_xywhs.append(xywh)
                            online_ids.append(tid)
                            online_scores.append(t.score)
                            # MOTChallenge-style line: frame,id,x,y,w,h,score,-1,-1,-1
                            results.append(
                                f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                            )
                    print("tracking time: {:.1f} ms".format((time.time() - t2) * 1000))

                    # plot tracking results
                    online_im = plot_tracking(
                        frame, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
                    )
                else:
                    online_im = frame

                frame_resized = cv2.resize(online_im, save_size)
                out.write(frame_resized)

                if args.gif:
                    gif_resized = cv2.resize(online_im, (640, 480))
                    # imageio expects RGB while OpenCV gives BGR: flip the channels
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                # show results
                if args.show:
                    cv2.imshow('tracking', online_im)
                    ch = cv2.waitKey(1)
                    if ch == 27 or ch == ord("q") or ch == ord("Q"):
                        break
            else:
                break
            frame_id += 1

        cap.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Video ---------------------------
    elif args.mode == 'video':
        # read a video
        video = cv2.VideoCapture(args.path_to_vid)
        fps = video.get(cv2.CAP_PROP_FPS)

        # for saving
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
        print(save_video_name)
        image_list = []

        # start tracking
        frame_id = 0
        results = []
        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                # preprocess
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.
                orig_h, orig_w, _ = frame.shape

                # detect
                t0 = time.time()
                bboxes, scores, labels = detector(x)
                print("=============== Frame-{} ================".format(frame_id))
                print("detect time: {:.1f} ms".format((time.time() - t0) * 1000))

                # rescale bboxes back to the original image size
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # track
                t2 = time.time()
                if len(bboxes) > 0:
                    online_targets = tracker.update(scores, bboxes, labels)
                    online_xywhs = []
                    online_ids = []
                    online_scores = []
                    for t in online_targets:
                        xywh = t.xywh
                        tid = t.track_id
                        vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
                        if xywh[2] * xywh[3] > args.min_box_area and not vertical:
                            online_xywhs.append(xywh)
                            online_ids.append(tid)
                            online_scores.append(t.score)
                            results.append(
                                f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                            )
                    print("tracking time: {:.1f} ms".format((time.time() - t2) * 1000))

                    # plot tracking results
                    online_im = plot_tracking(
                        frame, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
                    )
                else:
                    online_im = frame

                frame_resized = cv2.resize(online_im, save_size)
                out.write(frame_resized)

                if args.gif:
                    gif_resized = cv2.resize(online_im, (640, 480))
                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                    image_list.append(gif_resized_rgb)

                # show results
                if args.show:
                    cv2.imshow('tracking', online_im)
                    ch = cv2.waitKey(1)
                    if ch == 27 or ch == ord("q") or ch == ord("Q"):
                        break
            else:
                break
            frame_id += 1

        video.release()
        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=fps)
            print('GIF done: {}'.format(save_gif_name))
    # ------------------------- Image ----------------------------
    elif args.mode == 'image':
        files = get_image_list(args.path_to_img)
        files.sort()

        # for saving
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        save_size = (640, 480)
        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_video_name = os.path.join(save_path, cur_time + '.avi')
        out = cv2.VideoWriter(save_video_name, fourcc, args.fps, save_size)
        print(save_video_name)
        image_list = []

        # start tracking (enumerate numbers the frames, starting from 1)
        results = []
        for frame_id, img_path in enumerate(files, 1):
            image = cv2.imread(img_path)

            # preprocess
            x, _, deltas = transform(image)
            x = x.unsqueeze(0).to(device) / 255.
            orig_h, orig_w, _ = image.shape

            # detect
            t0 = time.time()
            bboxes, scores, labels = detector(x)
            print("=============== Frame-{} ================".format(frame_id))
            print("detect time: {:.1f} ms".format((time.time() - t0) * 1000))

            # rescale bboxes back to the original image size
            origin_img_size = [orig_h, orig_w]
            cur_img_size = [*x.shape[-2:]]
            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

            # track
            t2 = time.time()
            if len(bboxes) > 0:
                online_targets = tracker.update(scores, bboxes, labels)
                online_xywhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    xywh = t.xywh
                    tid = t.track_id
                    vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
                    if xywh[2] * xywh[3] > args.min_box_area and not vertical:
                        online_xywhs.append(xywh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                        results.append(
                            f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                        )
                print("tracking time: {:.1f} ms".format((time.time() - t2) * 1000))

                # plot tracking results
                online_im = plot_tracking(
                    image, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
                )
            else:
                online_im = image

            frame_resized = cv2.resize(online_im, save_size)
            out.write(frame_resized)

            if args.gif:
                gif_resized = cv2.resize(online_im, (640, 480))
                gif_resized_rgb = gif_resized[..., (2, 1, 0)]
                image_list.append(gif_resized_rgb)

            # show results
            if args.show:
                cv2.imshow('tracking', online_im)
                ch = cv2.waitKey(1)
                if ch == 27 or ch == ord("q") or ch == ord("Q"):
                    break

        out.release()
        cv2.destroyAllWindows()

        # generate GIF
        if args.gif:
            save_gif_path = os.path.join(save_path, 'gif_files')
            os.makedirs(save_gif_path, exist_ok=True)
            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
            print('generating GIF ...')
            imageio.mimsave(save_gif_name, image_list, fps=args.fps)
            print('GIF done: {}'.format(save_gif_name))
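
# NOTE: `run` accumulates MOTChallenge-style result lines in `results` but
# never writes them to disk. A minimal sketch of persisting them in the
# standard MOT txt layout (frame,id,x,y,w,h,score,-1,-1,-1); the helper
# below is a hypothetical addition, not part of the original script:
def write_mot_results(results, save_dir, name='results.txt'):
    # each entry in `results` is already a formatted, newline-terminated line
    res_file = os.path.join(save_dir, name)
    with open(res_file, 'w') as f:
        f.writelines(results)
    print('tracking results saved to: {}'.format(res_file))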

if __name__ == '__main__':
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    np.random.seed(0)

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # transform
    transform = build_transform(args.img_size, trans_cfg, is_train=False)

    # ---------------------- General Object Detector ----------------------
    detector = build_model(args, model_cfg, device, args.num_classes, False)

    # load trained weight
    detector = load_weight(detector, args.weight, args.fuse_conv_bn)
    detector.to(device).eval()

    # ---------------------- General Object Tracker ----------------------
    tracker = build_tracker(args)

    # run
    run(args=args,
        tracker=tracker,
        detector=detector,
        device=device,
        transform=transform)