|
|
@@ -1,8 +1,10 @@
|
|
|
import os
|
|
|
import cv2
|
|
|
import time
|
|
|
+import imageio
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
+
|
|
|
import torch
|
|
|
|
|
|
from dataset.data_augment import build_transform
|
|
|
@@ -43,11 +45,13 @@ def parse_args():
|
|
|
help='show results.')
|
|
|
parser.add_argument('--save', action='store_true', default=False,
|
|
|
help='save results.')
|
|
|
+ parser.add_argument('--gif', action='store_true', default=False,
|
|
|
+ help='generate gif.')
|
|
|
|
|
|
# tracker
|
|
|
parser.add_argument('-tk', '--tracker', default='byte_tracker', type=str,
|
|
|
help='build FreeTrack')
|
|
|
- parser.add_argument("--track_thresh", type=float, default=0.5,
|
|
|
+ parser.add_argument("--track_thresh", type=float, default=0.4,
|
|
|
help="tracking confidence threshold")
|
|
|
parser.add_argument("--track_buffer", type=int, default=30,
|
|
|
help="the frames for keep lost tracks")
|
|
|
@@ -96,15 +100,27 @@ def run(args,
|
|
|
detector,
|
|
|
device,
|
|
|
transform):
|
|
|
- save_path = os.path.join(args.path_to_save, args.mode)
|
|
|
+ save_path = os.path.join(args.path_to_save, 'tracking', args.mode)
|
|
|
os.makedirs(save_path, exist_ok=True)
|
|
|
|
|
|
# ------------------------- Camera ----------------------------
|
|
|
if args.mode == 'camera':
|
|
|
print('use camera !!!')
|
|
|
+ # Launch camera
|
|
|
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
|
|
frame_id = 0
|
|
|
results = []
|
|
|
+
|
|
|
+ # For saving
|
|
|
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
|
|
+ save_size = (640, 480)
|
|
|
+ cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
|
|
|
+ save_video_name = os.path.join(save_path, cur_time+'.avi')
|
|
|
+ fps = 15.0
|
|
|
+ out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
|
|
|
+ print(save_video_name)
|
|
|
+ image_list = []
|
|
|
+
|
|
|
# start tracking
|
|
|
while True:
|
|
|
ret, frame = cap.read()
|
|
|
@@ -155,101 +171,51 @@ def run(args,
|
|
|
else:
|
|
|
online_im = frame
|
|
|
|
|
|
+ frame_resized = cv2.resize(online_im, save_size)
|
|
|
+ out.write(frame_resized)
|
|
|
+
|
|
|
+ if args.gif:
|
|
|
+ gif_resized = cv2.resize(online_im, (640, 480))
|
|
|
+ gif_resized_rgb = gif_resized[..., (2, 1, 0)]
|
|
|
+ image_list.append(gif_resized_rgb)
|
|
|
+
|
|
|
# show results
|
|
|
if args.show:
|
|
|
cv2.imshow('tracking', online_im)
|
|
|
ch = cv2.waitKey(1)
|
|
|
if ch == 27 or ch == ord("q") or ch == ord("Q"):
|
|
|
break
|
|
|
-
|
|
|
else:
|
|
|
break
|
|
|
frame_id += 1
|
|
|
|
|
|
cap.release()
|
|
|
+ out.release()
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
|
- # ------------------------- Image ----------------------------
|
|
|
- elif args.mode == 'image':
|
|
|
- files = get_image_list(args.path_to_img)
|
|
|
- files.sort()
|
|
|
- # start tracking
|
|
|
- frame_id = 0
|
|
|
- results = []
|
|
|
- for frame_id, img_path in enumerate(files, 1):
|
|
|
- image = cv2.imread(os.path.join(img_path))
|
|
|
- # preprocess
|
|
|
- x, _, deltas = transform(image)
|
|
|
- x = x.unsqueeze(0).to(device) / 255.
|
|
|
- orig_h, orig_w, _ = image.shape
|
|
|
+ # generate GIF
|
|
|
+ if args.gif:
|
|
|
+ save_gif_path = os.path.join(save_path, 'gif_files')
|
|
|
+ os.makedirs(save_gif_path, exist_ok=True)
|
|
|
+ save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
|
|
|
+ print('generating GIF ...')
|
|
|
+ imageio.mimsave(save_gif_name, image_list, fps=fps)
|
|
|
+ print('GIF done: {}'.format(save_gif_name))
|
|
|
|
|
|
- # detect
|
|
|
- t0 = time.time()
|
|
|
- bboxes, scores, labels = detector(x)
|
|
|
- print("=============== Frame-{} ================".format(frame_id))
|
|
|
- print("detect time: {:.1f} ms".format((time.time() - t0)*1000))
|
|
|
-
|
|
|
- # rescale bboxes
|
|
|
- origin_img_size = [orig_h, orig_w]
|
|
|
- cur_img_size = [*x.shape[-2:]]
|
|
|
- bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
|
|
|
-
|
|
|
- # track
|
|
|
- t2 = time.time()
|
|
|
- if len(bboxes) > 0:
|
|
|
- online_targets = tracker.update(scores, bboxes, labels)
|
|
|
- online_xywhs = []
|
|
|
- online_ids = []
|
|
|
- online_scores = []
|
|
|
- for t in online_targets:
|
|
|
- xywh = t.xywh
|
|
|
- tid = t.track_id
|
|
|
- vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
|
|
|
- if xywh[2] * xywh[3] > args.min_box_area and not vertical:
|
|
|
- online_xywhs.append(xywh)
|
|
|
- online_ids.append(tid)
|
|
|
- online_scores.append(t.score)
|
|
|
- results.append(
|
|
|
- f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
|
|
|
- )
|
|
|
- print("tracking time: {:.1f} ms".format((time.time() - t2)*1000))
|
|
|
-
|
|
|
- # plot tracking results
|
|
|
- online_im = plot_tracking(
|
|
|
- image, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
|
|
|
- )
|
|
|
- else:
|
|
|
- online_im = image
|
|
|
-
|
|
|
- # save results
|
|
|
- if args.save:
|
|
|
- vid_writer.write(online_im)
|
|
|
- # show results
|
|
|
- if args.show:
|
|
|
- cv2.imshow('tracking', online_im)
|
|
|
- ch = cv2.waitKey(1)
|
|
|
- if ch == 27 or ch == ord("q") or ch == ord("Q"):
|
|
|
- break
|
|
|
-
|
|
|
- frame_id += 1
|
|
|
-
|
|
|
- cv2.destroyAllWindows()
|
|
|
-
|
|
|
# ------------------------- Video ---------------------------
|
|
|
elif args.mode == 'video':
|
|
|
# read a video
|
|
|
video = cv2.VideoCapture(args.path_to_vid)
|
|
|
- width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # float
|
|
|
- height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
|
|
|
- fps = cap.get(cv2.CAP_PROP_FPS)
|
|
|
+ fps = video.get(cv2.CAP_PROP_FPS)
|
|
|
|
|
|
- # path to save
|
|
|
- timestamp = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
|
|
|
- save_path = os.path.join(save_path, timestamp, args.path.split("/")[-1])
|
|
|
- vid_writer = cv2.VideoWriter(
|
|
|
- save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
|
|
|
- )
|
|
|
- print("Save path: {}".format(save_path))
|
|
|
+ # For saving
|
|
|
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
|
|
+ save_size = (640, 480)
|
|
|
+ cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
|
|
|
+ save_video_name = os.path.join(save_path, cur_time+'.avi')
|
|
|
+ out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
|
|
|
+ print(save_video_name)
|
|
|
+ image_list = []
|
|
|
|
|
|
# start tracking
|
|
|
frame_id = 0
|
|
|
@@ -302,9 +268,14 @@ def run(args,
|
|
|
else:
|
|
|
online_im = frame
|
|
|
|
|
|
- # save results
|
|
|
- if args.save:
|
|
|
- vid_writer.write(online_im)
|
|
|
+ frame_resized = cv2.resize(online_im, save_size)
|
|
|
+ out.write(frame_resized)
|
|
|
+
|
|
|
+ if args.gif:
|
|
|
+ gif_resized = cv2.resize(online_im, (640, 480))
|
|
|
+ gif_resized_rgb = gif_resized[..., (2, 1, 0)]
|
|
|
+ image_list.append(gif_resized_rgb)
|
|
|
+
|
|
|
# show results
|
|
|
if args.show:
|
|
|
cv2.imshow('tracking', online_im)
|
|
|
@@ -316,9 +287,110 @@ def run(args,
|
|
|
frame_id += 1
|
|
|
|
|
|
video.release()
|
|
|
- vid_writer.release()
|
|
|
+ out.release()
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
|
+ # generate GIF
|
|
|
+ if args.gif:
|
|
|
+ save_gif_path = os.path.join(save_path, 'gif_files')
|
|
|
+ os.makedirs(save_gif_path, exist_ok=True)
|
|
|
+ save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
|
|
|
+ print('generating GIF ...')
|
|
|
+ imageio.mimsave(save_gif_name, image_list, fps=fps)
|
|
|
+ print('GIF done: {}'.format(save_gif_name))
|
|
|
+
|
|
|
+ # ------------------------- Image ----------------------------
|
|
|
+ elif args.mode == 'image':
|
|
|
+ files = get_image_list(args.path_to_img)
|
|
|
+ files.sort()
|
|
|
+
|
|
|
+ # For saving
|
|
|
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
|
|
+ save_size = (640, 480)
|
|
|
+ cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
|
|
|
+ save_video_name = os.path.join(save_path, cur_time+'.avi')
|
|
|
+ out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
|
|
|
+ print(save_video_name)
|
|
|
+ image_list = []
|
|
|
+
|
|
|
+ # start tracking
|
|
|
+ frame_id = 0
|
|
|
+ results = []
|
|
|
+ for frame_id, img_path in enumerate(files, 1):
|
|
|
+ image = cv2.imread(os.path.join(img_path))
|
|
|
+ # preprocess
|
|
|
+ x, _, deltas = transform(image)
|
|
|
+ x = x.unsqueeze(0).to(device) / 255.
|
|
|
+ orig_h, orig_w, _ = image.shape
|
|
|
+
|
|
|
+ # detect
|
|
|
+ t0 = time.time()
|
|
|
+ bboxes, scores, labels = detector(x)
|
|
|
+ print("=============== Frame-{} ================".format(frame_id))
|
|
|
+ print("detect time: {:.1f} ms".format((time.time() - t0)*1000))
|
|
|
+
|
|
|
+ # rescale bboxes
|
|
|
+ origin_img_size = [orig_h, orig_w]
|
|
|
+ cur_img_size = [*x.shape[-2:]]
|
|
|
+ bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
|
|
|
+
|
|
|
+ # track
|
|
|
+ t2 = time.time()
|
|
|
+ if len(bboxes) > 0:
|
|
|
+ online_targets = tracker.update(scores, bboxes, labels)
|
|
|
+ online_xywhs = []
|
|
|
+ online_ids = []
|
|
|
+ online_scores = []
|
|
|
+ for t in online_targets:
|
|
|
+ xywh = t.xywh
|
|
|
+ tid = t.track_id
|
|
|
+ vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
|
|
|
+ if xywh[2] * xywh[3] > args.min_box_area and not vertical:
|
|
|
+ online_xywhs.append(xywh)
|
|
|
+ online_ids.append(tid)
|
|
|
+ online_scores.append(t.score)
|
|
|
+ results.append(
|
|
|
+ f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
|
|
|
+ )
|
|
|
+ print("tracking time: {:.1f} ms".format((time.time() - t2)*1000))
|
|
|
+
|
|
|
+ # plot tracking results
|
|
|
+ online_im = plot_tracking(
|
|
|
+ image, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ online_im = frame
|
|
|
+
|
|
|
+ frame_resized = cv2.resize(online_im, save_size)
|
|
|
+ out.write(frame_resized)
|
|
|
+
|
|
|
+ if args.gif:
|
|
|
+ gif_resized = cv2.resize(online_im, (640, 480))
|
|
|
+ gif_resized_rgb = gif_resized[..., (2, 1, 0)]
|
|
|
+ image_list.append(gif_resized_rgb)
|
|
|
+
|
|
|
+ # show results
|
|
|
+ if args.show:
|
|
|
+ cv2.imshow('tracking', online_im)
|
|
|
+ ch = cv2.waitKey(1)
|
|
|
+ if ch == 27 or ch == ord("q") or ch == ord("Q"):
|
|
|
+ break
|
|
|
+
|
|
|
+ frame_id += 1
|
|
|
+
|
|
|
+ cv2.destroyAllWindows()
|
|
|
+ out.release()
|
|
|
+ cv2.destroyAllWindows()
|
|
|
+
|
|
|
+ # generate GIF
|
|
|
+ if args.gif:
|
|
|
+ save_gif_path = os.path.join(save_path, 'gif_files')
|
|
|
+ os.makedirs(save_gif_path, exist_ok=True)
|
|
|
+ save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
|
|
|
+ print('generating GIF ...')
|
|
|
+ imageio.mimsave(save_gif_name, image_list, fps=fps)
|
|
|
+ print('GIF done: {}'.format(save_gif_name))
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
args = parse_args()
|