Browse source

update README

yjh0410 2 years ago
parent
commit
8c1d66b78f

+ 54 - 24
README.md

@@ -62,7 +62,7 @@ python dataset/voc.py
 
 For example:
 ```Shell
-python train.py --cuda -d voc --root path/to/VOCdevkit -v yolov1 -bs 16 --max_epoch 150 --wp_epoch 1 --eval_epoch 10 --fp16 --ema --multi_scale
+python train.py --cuda -d voc --root path/to/VOCdevkit -m yolov1 -bs 16 --max_epoch 150 --wp_epoch 1 --eval_epoch 10 --fp16 --ema --multi_scale
 ```
 
 | Model        |   Backbone          | Scale |  IP  | Epoch | AP<sup>val<br>0.5 | FPS<sup>3090<br>FP32-bs1 | Weight |
@@ -95,7 +95,7 @@ python dataset/coco.py
 
 For example:
 ```Shell
-python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch 150 --wp_epoch 1 --eval_epoch 10 --fp16 --ema --multi_scale
+python train.py --cuda -d coco --root path/to/COCO -m yolov1 -bs 16 --max_epoch 150 --wp_epoch 1 --eval_epoch 10 --fp16 --ema --multi_scale
 ```
 
 Due to my limited computing resources, I had to set the batch size to 16 or even smaller during training. I found that the performance of small models such as *-Nano or *-Tiny seems less sensitive to batch size: the YOLOv5-N and S I reproduced are even slightly stronger than the official YOLOv5-N and S. For large models such as *-Large, however, performance falls significantly below the official results, which suggests that large models are more sensitive to batch size.
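
When memory caps the batch size, gradient accumulation is a common way to approximate a larger effective batch: run several forward/backward passes before each optimizer step. A minimal, self-contained PyTorch sketch (toy model and data, not this repo's training loop):

```Python
import torch

model = torch.nn.Linear(10, 2)                 # stand-in for the detector
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
accum_steps = 4                                # effective batch = 4 x per-step batch

optimizer.zero_grad()
for step in range(16):
    x = torch.randn(4, 10)                     # per-step mini-batch of 4
    y = torch.randint(0, 2, (4,))
    loss = loss_fn(model(x), y) / accum_steps  # scale so accumulated grads match one big batch
    loss.backward()                            # gradients accumulate across the 4 steps
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Note that accumulation does not reproduce BatchNorm's large-batch statistics, so it may not fully close the gap for the *-Large models.
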
@@ -167,7 +167,7 @@ You can change the configurations of `train.sh` according to your own situation
 
 You can also add `--vis_tgt` to check the images and targets during the training stage. For example:
 ```Shell
-python train.py --cuda -d coco --root path/to/coco -v yolov1 --vis_tgt
+python train.py --cuda -d coco --root path/to/coco -m yolov1 --vis_tgt
 ```
 
 ### Multi GPUs
@@ -184,7 +184,7 @@ weight path (`None` by default) to resume training. For example:
 python train.py \
         --cuda \
         -d coco \
-        -v yolov1 \
+        -m yolov1 \
         -bs 16 \
         --max_epoch 300 \
         --wp_epoch 3 \
@@ -200,7 +200,7 @@ Then, training will continue from epoch 151.
 ```Shell
 python test.py -d coco \
                --cuda \
-               -v yolov1 \
+               -m yolov1 \
                --img_size 640 \
                --weight path/to/weight \
                --root path/to/dataset/ \
@@ -211,7 +211,7 @@ For YOLOv7, since it uses the RepConv in PaFPN, you can add `--fuse_repconv` to
 ```Shell
 python test.py -d coco \
                --cuda \
-               -v yolov7_large \
+               -m yolov7_large \
                --fuse_repconv \
                --img_size 640 \
                --weight path/to/weight \
@@ -224,7 +224,7 @@ python test.py -d coco \
 ```Shell
 python eval.py -d coco-val \
                --cuda \
-               -v yolov1 \
+               -m yolov1 \
                --img_size 640 \
                --weight path/to/weight \
                --root path/to/dataset/ \
@@ -237,10 +237,11 @@ I have provided some images in `data/demo/images/`, so you can run the following command
 ```Shell
 python demo.py --mode image \
                --path_to_img data/demo/images/ \
-               -v yolov1 \
-               --img_size 640 \
                --cuda \
-               --weight path/to/weight
+               --img_size 640 \
+               -m yolov1 \
+               --weight path/to/weight \
+               --show
 ```
 
 If you want to run a demo of streaming video detection, you need to set `--mode` to `video` and give the path to the video via `--path_to_vid`.
@@ -248,20 +249,24 @@ If you want to run a demo of streaming video detection, you need to set `--mode` to
 ```Shell
 python demo.py --mode video \
                --path_to_img data/demo/videos/your_video \
-               -v yolov1 \
-               --img_size 640 \
                --cuda \
-               --weight path/to/weight
+               --img_size 640 \
+               -m yolov1 \
+               --weight path/to/weight \
+               --show \
+               --gif
 ```
 
 If you want to run video detection with your camera, you need to set `--mode` to `camera`.
 
 ```Shell
 python demo.py --mode camera \
-               -v yolov1 \
-               --img_size 640 \
                --cuda \
-               --weight path/to/weight
+               --img_size 640 \
+               -m yolov1 \
+               --weight path/to/weight \
+               --show \
+               --gif
 ```
 
 ## Tracking
@@ -271,12 +276,13 @@ Our project also supports **multi-object tracking** tasks. We use the YOLO of th
 ```Shell
 python track.py --mode image \
                 --path_to_img path/to/images/ \
+                --cuda \
+                -size 640 \
                 -dt yolov2 \
                 -tk byte_tracker \
                 --weight path/to/coco_pretrained/ \
-                -size 640 \
-                --cuda \
-                --show
+                --show \
+                --gif
 ```
 
 * video tracking
@@ -284,22 +290,46 @@ python track.py --mode image \
 ```Shell
 python track.py --mode video \
                 --path_to_img path/to/video/ \
+                --cuda \
+                -size 640 \
                 -dt yolov2 \
                 -tk byte_tracker \
                 --weight path/to/coco_pretrained/ \
-                -size 640 \
-                --cuda \
-                --show
+                --show \
+                --gif
 ```
 
 * camera tracking
 
 ```Shell
 python track.py --mode camera \
+                --cuda \
+                -size 640 \
                 -dt yolov2 \
                 -tk byte_tracker \
                 --weight path/to/coco_pretrained/ \
-                -size 640 \
+                --show \
+                --gif
+```
+
+### Tracking visualization
+* Detector: YOLOv2
+* Tracker: ByteTracker
+
+Command:
+
+```Shell
+python track.py --mode video \
+                --path_to_img ./dataset/demo/videos/000006.mp4 \
                 --cuda \
-                --show
+                -size 640 \
+                -dt yolov2 \
+                -tk byte_tracker \
+                --weight path/to/coco_pretrained/ \
+                --show \
+                --gif
 ```
+
+Results:
+
+![image](./img_files/video_tracking_demo.gif)

+ 47 - 17
README_CN.md

@@ -243,10 +243,11 @@ python eval.py -d coco-val \
 ```Shell
 python demo.py --mode image \
                --path_to_img data/demo/images/ \
-               -v yolov1 \
-               --img_size 640 \
                --cuda \
-               --weight path/to/weight
+               --img_size 640 \
+               -m yolov2 \
+               --weight path/to/weight \
+               --show
 ```
 
 If you want to test on a local video, change `--mode image` in the command above to `--mode video` and pass the path to the video file via `--path_to_vid`, for example:
@@ -254,20 +255,24 @@ python demo.py --mode image \
 ```Shell
 python demo.py --mode video \
                --path_to_img data/demo/videos/your_video \
-               -v yolov1 \
-               --img_size 640 \
                --cuda \
-               --weight path/to/weight
+               --img_size 640 \
+               -m yolov2 \
+               --weight path/to/weight \
+               --show \
+               --gif
 ```
 
 If you want to test with a local camera (e.g. a laptop's built-in camera), change `--mode image` in the command above to `--mode camera`, for example:
 
 ```Shell
 python demo.py --mode camera \
-               -v yolov1 \
-               --img_size 640 \
                --cuda \
-               --weight path/to/weight
+               --img_size 640 \
+               -m yolov2 \
+               --weight path/to/weight \
+               --show \
+               --gif
 ```
 
 
@@ -278,12 +283,13 @@ python demo.py --mode camera \
 ```Shell
 python track.py --mode image \
                 --path_to_img path/to/images/ \
+                --cuda \
+                -size 640 \
                 -dt yolov2 \
                 -tk byte_tracker \
                 --weight path/to/coco_pretrained/ \
-                -size 640 \
-                --cuda \
-                --show
+                --show \
+                --gif
 ```
 
 * video tracking
@@ -291,22 +297,46 @@ python track.py --mode image \
 ```Shell
 python track.py --mode video \
                 --path_to_img path/to/video/ \
+                --cuda \
+                -size 640 \
                 -dt yolov2 \
                 -tk byte_tracker \
                 --weight path/to/coco_pretrained/ \
-                -size 640 \
-                --cuda \
-                --show
+                --show \
+                --gif
 ```
 
 * camera tracking
 
 ```Shell
 python track.py --mode camera \
+                --cuda \
+                -size 640 \
                 -dt yolov2 \
                 -tk byte_tracker \
                 --weight path/to/coco_pretrained/ \
-                -size 640 \
+                --show \
+                --gif
+```
+
+### Multi-object tracking example
+* Detector: YOLOv2
+* Tracker: ByteTracker
+
+Command:
+
+```Shell
+python track.py --mode video \
+                --path_to_img ./dataset/demo/videos/000006.mp4 \
                 --cuda \
-                --show
+                -size 640 \
+                -dt yolov2 \
+                -tk byte_tracker \
+                --weight path/to/coco_pretrained/ \
+                --show \
+                --gif
 ```
+
+Results:
+
+![image](./img_files/video_tracking_demo.gif)

BIN
dataset/demo/images/000000000632.jpg


BIN
dataset/demo/images/000000000785.jpg


BIN
dataset/demo/images/000000000872.jpg


BIN
dataset/demo/images/000000000885.jpg


BIN
dataset/demo/images/000000001000.jpg


BIN
dataset/demo/images/000000001268.jpg


BIN
dataset/demo/images/000000001296.jpg


BIN
dataset/demo/images/000000001503.jpg


BIN
dataset/demo/images/000000001532.jpg


BIN
dataset/demo/videos/000006.mp4


+ 287 - 0
demo.py

@@ -0,0 +1,287 @@
+import argparse
+import cv2
+import os
+import time
+import numpy as np
+import imageio
+
+import torch
+
+# load transform
+from dataset.data_augment import build_transform
+
+# load some utils
+from utils.misc import load_weight
+from utils.box_ops import rescale_bboxes
+from utils.vis_tools import visualize
+
+from models.detectors import build_model
+from config import build_model_config, build_trans_config
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='YOLO Demo')
+
+    # basic
+    parser.add_argument('-size', '--img_size', default=640, type=int,
+                        help='the max size of input image')
+    parser.add_argument('--mode', default='image',
+                        type=str, help='Use the data from image, video or camera')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='Use cuda')
+    parser.add_argument('--path_to_img', default='dataset/demo/images/',
+                        type=str, help='The path to image files')
+    parser.add_argument('--path_to_vid', default='dataset/demo/videos/',
+                        type=str, help='The path to video files')
+    parser.add_argument('--path_to_save', default='det_results/demos/',
+                        type=str, help='The path to save the detection results')
+    parser.add_argument('-vt', '--vis_thresh', default=0.4, type=float,
+                        help='Final confidence threshold for visualization')
+    parser.add_argument('--show', action='store_true', default=False,
+                        help='show visualization')
+    parser.add_argument('--gif', action='store_true', default=False, 
+                        help='generate gif.')
+
+    # model
+    parser.add_argument('-m', '--model', default='yolov1', type=str,
+                        help='build yolo')
+    parser.add_argument('-nc', '--num_classes', default=80, type=int,
+                        help='number of classes.')
+    parser.add_argument('--weight', default=None,
+                        type=str, help='Trained state_dict file path to open')
+    parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
+                        help='confidence threshold')
+    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
+                        help='NMS threshold')
+    parser.add_argument('--topk', default=100, type=int,
+                        help='topk candidates for testing')
+    parser.add_argument("--deploy", action="store_true", default=False,
+                        help="deploy mode or not")
+    parser.add_argument('--fuse_repconv', action='store_true', default=False,
+                        help='fuse RepConv')
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
+
+    return parser.parse_args()
+                    
+
+def detect(args,
+           model, 
+           device, 
+           transform, 
+           vis_thresh, 
+           mode='image'):
+    # class color
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(80)]
+    save_path = os.path.join(args.path_to_save, mode)
+    os.makedirs(save_path, exist_ok=True)
+
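+    # All three modes below share the same per-frame pipeline: transform the
+    # frame -> model forward -> rescale boxes back to the original resolution
+    # using the transform's deltas -> draw detections.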
+    # ------------------------- Camera ----------------------------
+    if mode == 'camera':
+        print('use camera !!!')
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        fps = 15.0
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
+
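+        # cv2.CAP_DSHOW selects the Windows DirectShow backend; on Linux/macOS
+        # a plain cv2.VideoCapture(0) may be needed instead.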
+        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+        while True:
+            ret, frame = cap.read()
+            if ret:
+                if cv2.waitKey(1) == ord('q'):
+                    break
+                orig_h, orig_w, _ = frame.shape
+
+                # prepare
+                x, _, deltas = transform(frame)
+                x = x.unsqueeze(0).to(device) / 255.
+                
+                # inference
+                t0 = time.time()
+                bboxes, scores, labels = model(x)
+                t1 = time.time()
+                print("detection time used ", t1-t0, "s")
+
+                # rescale bboxes
+                origin_img_size = [orig_h, orig_w]
+                cur_img_size = [*x.shape[-2:]]
+                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
+
+                # vis detection
+                frame_vis = visualize(img=frame, 
+                                      bboxes=bboxes,
+                                      scores=scores, 
+                                      labels=labels,
+                                      class_colors=class_colors,
+                                      vis_thresh=vis_thresh)
+                frame_resized = cv2.resize(frame_vis, save_size)
+                out.write(frame_resized)
+
+                if args.gif:
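+                    # OpenCV frames are BGR; reverse the channel order to RGB for imageio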
+                    gif_resized = cv2.resize(frame_vis, (640, 480))  # use the annotated frame so the GIF shows detections
+                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                    image_list.append(gif_resized_rgb)
+
+                if args.show:
+                    cv2.imshow('detection', frame_resized)
+                    cv2.waitKey(1)
+            else:
+                break
+        cap.release()
+        out.release()
+        cv2.destroyAllWindows()
+
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
+
+    # ------------------------- Video ---------------------------
+    elif mode == 'video':
+        video = cv2.VideoCapture(args.path_to_vid)
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        fps = 15.0
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
+
+        while(True):
+            ret, frame = video.read()
+            
+            if ret:
+                # ------------------------- Detection ---------------------------
+                orig_h, orig_w, _ = frame.shape
+
+                # prepare
+                x, _, deltas = transform(frame)
+                x = x.unsqueeze(0).to(device) / 255.
+
+                # inference
+                t0 = time.time()
+                bboxes, scores, labels = model(x)
+                t1 = time.time()
+                print("detection time used ", t1-t0, "s")
+
+                # rescale bboxes
+                origin_img_size = [orig_h, orig_w]
+                cur_img_size = [*x.shape[-2:]]
+                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
+
+                # vis detection
+                frame_vis = visualize(img=frame, 
+                                      bboxes=bboxes,
+                                      scores=scores, 
+                                      labels=labels,
+                                      class_colors=class_colors,
+                                      vis_thresh=vis_thresh)
+
+                frame_resized = cv2.resize(frame_vis, save_size)
+                out.write(frame_resized)
+
+                if args.gif:
+                    gif_resized = cv2.resize(frame_vis, (640, 480))  # annotated frame, as in the camera branch
+                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                    image_list.append(gif_resized_rgb)
+
+                if args.show:
+                    cv2.imshow('detection', frame_resized)
+                    cv2.waitKey(1)
+            else:
+                break
+        video.release()
+        out.release()
+        cv2.destroyAllWindows()
+
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
+
+    # ------------------------- Image ----------------------------
+    elif mode == 'image':
+        for i, img_id in enumerate(os.listdir(args.path_to_img)):
+            image = cv2.imread((args.path_to_img + '/' + img_id), cv2.IMREAD_COLOR)
+            orig_h, orig_w, _ = image.shape
+
+            # prepare
+            x, _, deltas = transform(image)
+            x = x.unsqueeze(0).to(device) / 255.
+
+            # inference
+            t0 = time.time()
+            bboxes, scores, labels = model(x)
+            t1 = time.time()
+            print("detection time used ", t1-t0, "s")
+
+            # rescale bboxes
+            origin_img_size = [orig_h, orig_w]
+            cur_img_size = [*x.shape[-2:]]
+            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
+
+            # vis detection
+            img_processed = visualize(img=image, 
+                                      bboxes=bboxes,
+                                      scores=scores, 
+                                      labels=labels,
+                                      class_colors=class_colors,
+                                      vis_thresh=vis_thresh)
+            cv2.imwrite(os.path.join(save_path, str(i).zfill(6)+'.jpg'), img_processed)
+            if args.show:
+                cv2.imshow('detection', img_processed)
+                cv2.waitKey(0)
+
+
+def run():
+    args = parse_args()
+    # cuda
+    if args.cuda:
+        print('use cuda')
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    # config
+    model_cfg = build_model_config(args)
+    trans_cfg = build_trans_config(model_cfg['trans_type'])
+
+    # build model
+    model = build_model(args, model_cfg, device, args.num_classes, False)
+
+    # load trained weight
+    model = load_weight(model, args.weight, args.fuse_conv_bn)
+    model.to(device).eval()
+
+    # transform
+    transform = build_transform(args.img_size, trans_cfg, is_train=False)
+
+    print("================= DETECT =================")
+    # run
+    detect(args=args,
+           model=model,
+           device=device,
+           transform=transform,
+           mode=args.mode,
+           vis_thresh=args.vis_thresh)
+
+
+if __name__ == '__main__':
+    run()

BIN
img_files/video_tracking_demo.gif


+ 25 - 0
requirements.txt

@@ -0,0 +1,25 @@
+torch
+
+torchvision
+
+opencv-python
+
+thop
+
+scipy
+
+matplotlib
+
+numpy
+
+imageio
+
+pycocotools
+
+onnxsim
+
+onnxruntime
+
+openvino
+
+loguru
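
All of the above can be installed in one step:

```Shell
pip install -r requirements.txt
```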

+ 3 - 3
test.py

@@ -114,7 +114,7 @@ def test(args,
          model, 
          device, 
          dataset,
-         transforms=None,
+         transform=None,
          class_colors=None, 
          class_names=None, 
          class_indexs=None):
@@ -129,7 +129,7 @@ def test(args,
         orig_h, orig_w, _ = image.shape
 
         # prepare
-        x, _, deltas = transforms(image)
+        x, _, deltas = transform(image)
         x = x.unsqueeze(0).to(device) / 255.
 
         t0 = time.time()
@@ -214,7 +214,7 @@ if __name__ == '__main__':
          model=model, 
          device=device, 
          dataset=dataset,
-         transforms=transform,
+         transform=transform,
          class_colors=class_colors,
          class_names=class_names,
          class_indexs=class_indexs

+ 154 - 82
track.py

@@ -1,8 +1,10 @@
 import os
 import cv2
 import time
+import imageio
 import argparse
 import numpy as np
+
 import torch
 
 from dataset.data_augment import build_transform
@@ -43,11 +45,13 @@ def parse_args():
                         help='show results.')
     parser.add_argument('--save', action='store_true', default=False, 
                         help='save results.')
+    parser.add_argument('--gif', action='store_true', default=False, 
+                        help='generate gif.')
 
     # tracker
     parser.add_argument('-tk', '--tracker', default='byte_tracker', type=str,
                         help='build FreeTrack')
-    parser.add_argument("--track_thresh", type=float, default=0.5, 
+    parser.add_argument("--track_thresh", type=float, default=0.4, 
                         help="tracking confidence threshold")
     parser.add_argument("--track_buffer", type=int, default=30, 
                         help="the frames for keep lost tracks")
@@ -96,15 +100,27 @@ def run(args,
         detector,
         device, 
         transform):
-    save_path = os.path.join(args.path_to_save, args.mode)
+    save_path = os.path.join(args.path_to_save, 'tracking', args.mode)
     os.makedirs(save_path, exist_ok=True)
 
     # ------------------------- Camera ----------------------------
     if args.mode == 'camera':
         print('use camera !!!')
+        # Launch camera
         cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
         frame_id = 0
         results = []
+
+        # For saving
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        fps = 15.0
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
+
         # start tracking
         while True:
             ret, frame = cap.read()
@@ -155,101 +171,51 @@ def run(args,
                 else:
                     online_im = frame
 
+                frame_resized = cv2.resize(online_im, save_size)
+                out.write(frame_resized)
+
+                if args.gif:
+                    gif_resized = cv2.resize(online_im, (640, 480))
+                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                    image_list.append(gif_resized_rgb)
+
                 # show results
                 if args.show:
                     cv2.imshow('tracking', online_im)
                     ch = cv2.waitKey(1)
                     if ch == 27 or ch == ord("q") or ch == ord("Q"):
                         break
-
             else:
                 break
             frame_id += 1
 
         cap.release()
+        out.release()
         cv2.destroyAllWindows()
 
-    # ------------------------- Image ----------------------------
-    elif args.mode == 'image':
-        files = get_image_list(args.path_to_img)
-        files.sort()
-        # start tracking
-        frame_id = 0
-        results = []
-        for frame_id, img_path in enumerate(files, 1):
-            image = cv2.imread(os.path.join(img_path))
-            # preprocess
-            x, _, deltas = transform(image)
-            x = x.unsqueeze(0).to(device) / 255.
-            orig_h, orig_w, _ = image.shape
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
 
-            # detect
-            t0 = time.time()
-            bboxes, scores, labels = detector(x)
-            print("=============== Frame-{} ================".format(frame_id))
-            print("detect time: {:.1f} ms".format((time.time() - t0)*1000))
-
-            # rescale bboxes
-            origin_img_size = [orig_h, orig_w]
-            cur_img_size = [*x.shape[-2:]]
-            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
-
-            # track
-            t2 = time.time()
-            if len(bboxes) > 0:
-                online_targets = tracker.update(scores, bboxes, labels)
-                online_xywhs = []
-                online_ids = []
-                online_scores = []
-                for t in online_targets:
-                    xywh = t.xywh
-                    tid = t.track_id
-                    vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
-                    if xywh[2] * xywh[3] > args.min_box_area and not vertical:
-                        online_xywhs.append(xywh)
-                        online_ids.append(tid)
-                        online_scores.append(t.score)
-                        results.append(
-                            f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
-                            )
-                print("tracking time: {:.1f} ms".format((time.time() - t2)*1000))
-                
-                # plot tracking results
-                online_im = plot_tracking(
-                    image, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
-                )
-            else:
-                online_im = image
-
-            # save results
-            if args.save:
-                vid_writer.write(online_im)
-            # show results
-            if args.show:
-                cv2.imshow('tracking', online_im)
-                ch = cv2.waitKey(1)
-                if ch == 27 or ch == ord("q") or ch == ord("Q"):
-                    break
-
-            frame_id += 1
-
-        cv2.destroyAllWindows()
-            
     # ------------------------- Video ---------------------------
     elif args.mode == 'video':
         # read a video
         video = cv2.VideoCapture(args.path_to_vid)
-        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
-        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
-        fps = cap.get(cv2.CAP_PROP_FPS)
+        fps = video.get(cv2.CAP_PROP_FPS)
         
-        # path to save
-        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
-        save_path = os.path.join(save_path, timestamp, args.path.split("/")[-1])
-        vid_writer = cv2.VideoWriter(
-            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
-        )
-        print("Save path: {}".format(save_path))
+        # For saving
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
 
         # start tracking
         frame_id = 0
@@ -302,9 +268,14 @@ def run(args,
                 else:
                     online_im = frame
 
-                # save results
-                if args.save:
-                    vid_writer.write(online_im)
+                frame_resized = cv2.resize(online_im, save_size)
+                out.write(frame_resized)
+
+                if args.gif:
+                    gif_resized = cv2.resize(online_im, (640, 480))
+                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                    image_list.append(gif_resized_rgb)
+
                 # show results
                 if args.show:
                     cv2.imshow('tracking', online_im)
@@ -316,9 +287,110 @@ def run(args,
             frame_id += 1
 
         video.release()
-        vid_writer.release()
+        out.release()
         cv2.destroyAllWindows()
 
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
+
+    # ------------------------- Image ----------------------------
+    elif args.mode == 'image':
+        files = get_image_list(args.path_to_img)
+        files.sort()
+
+        # For saving
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        fps = 15.0  # fps is not otherwise defined in this branch; 15.0 matches the camera branch
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
+
+        # start tracking
+        frame_id = 0
+        results = []
+        for frame_id, img_path in enumerate(files, 1):
+            image = cv2.imread(os.path.join(img_path))
+            # preprocess
+            x, _, deltas = transform(image)
+            x = x.unsqueeze(0).to(device) / 255.
+            orig_h, orig_w, _ = image.shape
+
+            # detect
+            t0 = time.time()
+            bboxes, scores, labels = detector(x)
+            print("=============== Frame-{} ================".format(frame_id))
+            print("detect time: {:.1f} ms".format((time.time() - t0)*1000))
+
+            # rescale bboxes
+            origin_img_size = [orig_h, orig_w]
+            cur_img_size = [*x.shape[-2:]]
+            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
+
+            # track
+            t2 = time.time()
+            if len(bboxes) > 0:
+                online_targets = tracker.update(scores, bboxes, labels)
+                online_xywhs = []
+                online_ids = []
+                online_scores = []
+                for t in online_targets:
+                    xywh = t.xywh
+                    tid = t.track_id
+                    vertical = xywh[2] / xywh[3] > args.aspect_ratio_thresh
+                    if xywh[2] * xywh[3] > args.min_box_area and not vertical:
+                        online_xywhs.append(xywh)
+                        online_ids.append(tid)
+                        online_scores.append(t.score)
+                        results.append(
+                            f"{frame_id},{tid},{xywh[0]:.2f},{xywh[1]:.2f},{xywh[2]:.2f},{xywh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
+                            )
+                print("tracking time: {:.1f} ms".format((time.time() - t2)*1000))
+                
+                # plot tracking results
+                online_im = plot_tracking(
+                    image, online_xywhs, online_ids, frame_id=frame_id + 1, fps=1. / (time.time() - t0)
+                )
+            else:
+                online_im = image  # this branch iterates over `image`; `frame` is undefined here
+
+            frame_resized = cv2.resize(online_im, save_size)
+            out.write(frame_resized)
+
+            if args.gif:
+                gif_resized = cv2.resize(online_im, (640, 480))
+                gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                image_list.append(gif_resized_rgb)
+
+            # show results
+            if args.show:
+                cv2.imshow('tracking', online_im)
+                ch = cv2.waitKey(1)
+                if ch == 27 or ch == ord("q") or ch == ord("Q"):
+                    break
+
+            frame_id += 1
+
+        out.release()
+        cv2.destroyAllWindows()
+
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
+
 
 if __name__ == '__main__':
     args = parse_args()
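
A closing observation on `track.py` and `demo.py`: the video-writer setup and GIF-generation blocks are repeated verbatim across the camera, video, and image branches. A small helper along these lines (hypothetical, not part of this commit) would remove the duplication:

```Python
import os
import imageio

def save_gif(image_list, save_path, cur_time, fps=15.0):
    # Sketch of a possible refactor: write the collected RGB frames as a GIF.
    save_gif_path = os.path.join(save_path, 'gif_files')
    os.makedirs(save_gif_path, exist_ok=True)
    save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
    print('generating GIF ...')
    imageio.mimsave(save_gif_name, image_list, fps=fps)
    print('GIF done: {}'.format(save_gif_name))
```

Each branch would then end with a single `save_gif(image_list, save_path, cur_time, fps)` call when `args.gif` is set.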