# track.py

import os
import cv2
import time
import argparse
import numpy as np
import torch

from dataset.data_augment import build_transform
from utils.vis_tools import plot_tracking
from utils.misc import load_weight
from utils.box_ops import rescale_bboxes
from config import build_model_config, build_trans_config
from models.detectors import build_model
from models.tracker import build_tracker

# allow duplicate OpenMP runtimes (works around a common libiomp crash with conda/MKL)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
def parse_args():
    parser = argparse.ArgumentParser(description='Tracking Task')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    # data
    parser.add_argument('--mode', type=str, default='image',
                        help='image, video or camera')
    parser.add_argument('--path_to_img', type=str, default='dataset/demo/images/',
                        help='directory to load images from')
    parser.add_argument('--path_to_vid', type=str, default='dataset/demo/videos/',
                        help='path to the input video')
    parser.add_argument('--path_to_save', default='det_results/', type=str,
                        help='directory to save results to')
    parser.add_argument('--fps', type=int, default=30,
                        help='frame rate')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show results.')
    parser.add_argument('--save', action='store_true', default=False,
                        help='save results.')
    # tracker
    parser.add_argument('-tk', '--tracker', default='byte_tracker', type=str,
                        help='tracker to build')
    parser.add_argument("--track_thresh", type=float, default=0.5,
                        help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30,
                        help="number of frames to keep lost tracks alive")
    parser.add_argument("--match_thresh", type=float, default=0.8,
                        help="matching threshold for tracking")
    parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6,
                        help="filter out boxes whose aspect ratio is above this value")
    parser.add_argument('--min_box_area', type=float, default=10,
                        help='filter out tiny boxes')
    parser.add_argument("--mot20", default=False, action="store_true",
                        help="test mot20.")
    # detector
    parser.add_argument('-dt', '--model', default='yolov1', type=str,
                        help='detector to build')
    parser.add_argument('-ns', '--num_classes', type=int, default=80,
                        help='number of object classes.')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.3, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='top-k candidates for testing')
    parser.add_argument('-fcb', '--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')

    return parser.parse_args()
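
# Example invocations (the weight path and demo video name below are
# illustrative assumptions, not files guaranteed to ship with this code):
#   python track.py --mode image --path_to_img dataset/demo/images/ \
#       -dt yolov1 --weight weights/yolov1.pth --show
#   python track.py --mode video --path_to_vid dataset/demo/videos/demo.mp4 \
#       -dt yolov1 --weight weights/yolov1.pth --cuda --save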

def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)

    return image_names

def run(args, tracker, detector, device, transform):
    save_path = os.path.join(args.path_to_save, args.mode)
    os.makedirs(save_path, exist_ok=True)

    # ------------------------- Camera ----------------------------
    if args.mode == 'camera':
        print('use camera !!!')
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        frame_id = 0
        results = []

        # start tracking
        while True:
            ret, frame = cap.read()
            if ret:
                if cv2.waitKey(1) == ord('q'):
                    break

                # ------------------------- Detection ---------------------------
                # preprocess
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.
                orig_h, orig_w, _ = frame.shape

                # detect
                t0 = time.time()
                bboxes, scores, labels = detector(x)
                print("=============== Frame-{} ================".format(frame_id))
                print("detect time: {:.1f} ms".format((time.time() - t0) * 1000))

                # rescale bboxes
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # track
                t2 = time.time()
                if len(bboxes) > 0:
                    online_targets = tracker.update(scores, bboxes, labels)
                    online_xywhs = []
                    online_ids = []
                    online_scores = []
                    for t in online_targets:
                        xywh = t.xywh
                        tid = t.track_id
                        vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
                        if xywh[2] * xywh[3] > args.min_box_area and not vertical:
                            online_xywhs.append(xywh)
                            online_ids.append(tid)
                            online_scores.append(t.score)
                            # MOT-style record: frame, id, x, y, w, h, score, -1, -1, -1
                            results.append(
                                f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                            )
                    print("tracking time: {:.1f} ms".format((time.time() - t2) * 1000))

                    # plot tracking results
                    online_im = plot_tracking(
                        frame, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
                    )
                else:
                    online_im = frame

                # show results
                if args.show:
                    cv2.imshow('tracking', online_im)
                    ch = cv2.waitKey(1)
                    if ch == 27 or ch == ord("q") or ch == ord("Q"):
                        break
            else:
                break
            frame_id += 1

        cap.release()
        cv2.destroyAllWindows()
  148. # ------------------------- Image ----------------------------
  149. elif args.mode == 'image':
  150. files = get_image_list(args.path_to_img)
  151. files.sort()
  152. # start tracking
  153. frame_id = 0
  154. results = []
  155. for frame_id, img_path in enumerate(files, 1):
  156. image = cv2.imread(os.path.join(img_path))
  157. # preprocess
  158. x, _, deltas = transform(image)
  159. x = x.unsqueeze(0).to(device) / 255.
  160. orig_h, orig_w, _ = image.shape
  161. # detect
  162. t0 = time.time()
  163. bboxes, scores, labels = detector(x)
  164. print("=============== Frame-{} ================".format(frame_id))
  165. print("detect time: {:.1f} ms".format((time.time() - t0)*1000))
  166. # rescale bboxes
  167. origin_img_size = [orig_h, orig_w]
  168. cur_img_size = [*x.shape[-2:]]
  169. bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
  170. # track
  171. t2 = time.time()
  172. if len(bboxes) > 0:
  173. online_targets = tracker.update(scores, bboxes, labels)
  174. online_xywhs = []
  175. online_ids = []
  176. online_scores = []
  177. for t in online_targets:
  178. xywh = t.xywh
  179. tid = t.track_id
  180. vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
  181. if xywh[2] * xywh[3] > args.min_box_area and not vertical:
  182. online_xywhs.append(xywh)
  183. online_ids.append(tid)
  184. online_scores.append(t.score)
  185. results.append(
  186. f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
  187. )
  188. print("tracking time: {:.1f} ms".format((time.time() - t2)*1000))
  189. # plot tracking results
  190. online_im = plot_tracking(
  191. image, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
  192. )
  193. else:
  194. online_im = image
  195. # save results
  196. if args.save:
  197. vid_writer.write(online_im)
  198. # show results
  199. if args.show:
  200. cv2.imshow('tracking', online_im)
  201. ch = cv2.waitKey(1)
  202. if ch == 27 or ch == ord("q") or ch == ord("Q"):
  203. break
  204. frame_id += 1
  205. cv2.destroyAllWindows()
    # ------------------------- Video ---------------------------
    elif args.mode == 'video':
        # read a video
        video = cv2.VideoCapture(args.path_to_vid)
        width = video.get(cv2.CAP_PROP_FRAME_WIDTH)    # float
        height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
        fps = video.get(cv2.CAP_PROP_FPS)

        # path to save
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        save_path = os.path.join(save_path, timestamp, os.path.basename(args.path_to_vid))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        vid_writer = cv2.VideoWriter(
            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
        )
        print("Save path: {}".format(save_path))

        # start tracking
        frame_id = 0
        results = []
        while True:
            ret, frame = video.read()
            if ret:
                # ------------------------- Detection ---------------------------
                # preprocess
                x, _, deltas = transform(frame)
                x = x.unsqueeze(0).to(device) / 255.
                orig_h, orig_w, _ = frame.shape

                # detect
                t0 = time.time()
                bboxes, scores, labels = detector(x)
                print("=============== Frame-{} ================".format(frame_id))
                print("detect time: {:.1f} ms".format((time.time() - t0) * 1000))

                # rescale bboxes
                origin_img_size = [orig_h, orig_w]
                cur_img_size = [*x.shape[-2:]]
                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

                # track
                t2 = time.time()
                if len(bboxes) > 0:
                    online_targets = tracker.update(scores, bboxes, labels)
                    online_xywhs = []
                    online_ids = []
                    online_scores = []
                    for t in online_targets:
                        xywh = t.xywh
                        tid = t.track_id
                        vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
                        if xywh[2] * xywh[3] > args.min_box_area and not vertical:
                            online_xywhs.append(xywh)
                            online_ids.append(tid)
                            online_scores.append(t.score)
                            results.append(
                                f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                            )
                    print("tracking time: {:.1f} ms".format((time.time() - t2) * 1000))

                    # plot tracking results
                    online_im = plot_tracking(
                        frame, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
                    )
                else:
                    online_im = frame

                # save results
                if args.save:
                    vid_writer.write(online_im)

                # show results
                if args.show:
                    cv2.imshow('tracking', online_im)
                    ch = cv2.waitKey(1)
                    if ch == 27 or ch == ord("q") or ch == ord("Q"):
                        break
            else:
                break
            frame_id += 1

        video.release()
        vid_writer.release()
        cv2.destroyAllWindows()
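
# Note: run() accumulates MOT-format rows in `results` but never writes them
# out. A minimal sketch of how they could be persisted; the helper name and
# default file name are assumptions, not part of the original script:
def save_mot_results(results, save_path, name="tracking_results.txt"):
    """Write accumulated 'frame,id,x,y,w,h,score,-1,-1,-1' rows to a text file."""
    res_file = os.path.join(save_path, name)
    with open(res_file, 'w') as f:
        f.writelines(results)
    print("Tracking results saved to {}".format(res_file))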

if __name__ == '__main__':
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    np.random.seed(0)

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # transform
    transform = build_transform(args.img_size, trans_cfg, is_train=False)

    # ---------------------- General Object Detector ----------------------
    detector = build_model(args, model_cfg, device, args.num_classes, False)

    ## load trained weight
    detector = load_weight(detector, args.weight, args.fuse_conv_bn)
    detector.to(device).eval()

    # ---------------------- General Object Tracker ----------------------
    tracker = build_tracker(args)

    # run
    run(args=args,
        tracker=tracker,
        detector=detector,
        device=device,
        transform=transform)