Преглед на файлове

optimize load_cache function

冬落 преди 2 години
родител
ревизия
4602d0a33e
променени са 5 файла, в които са добавени 391 реда и са изтрити 296 реда
  1. 3 3
      dataset/build.py
  2. 54 75
      dataset/coco.py
  3. 208 0
      dataset/make_dataset.py
  4. 52 80
      dataset/ourdataset.py
  5. 74 138
      dataset/voc.py

+ 3 - 3
dataset/build.py

@@ -1,14 +1,14 @@
 import os
 
 try:
-    from .voc import VOCDetection
+    from .voc import VOCDataset
     from .coco import COCODataset
     from .ourdataset import OurDataset
     from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
     from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
 
 except:
-    from voc import VOCDetection
+    from voc import VOCDataset
     from coco import COCODataset
     from ourdataset import OurDataset
     from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
@@ -31,7 +31,7 @@ def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
     # ------------------------- Build dataset -------------------------
     ## VOC dataset
     if args.dataset == 'voc':
-        dataset = VOCDetection(
+        dataset = VOCDataset(
             img_size=args.img_size,
             data_dir=data_dir,
             image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],

+ 54 - 75
dataset/coco.py

@@ -1,11 +1,11 @@
 import os
+import cv2
+import time
 import random
 import numpy as np
-import time
 
 import torch
 from torch.utils.data import Dataset
-import cv2
 
 try:
     from pycocotools.coco import COCO
@@ -44,13 +44,13 @@ class COCODataset(Dataset):
     COCO dataset class.
     """
     def __init__(self, 
-                 img_size=640,
-                 data_dir=None, 
-                 image_set='train2017',
-                 trans_config=None,
-                 transform=None,
-                 is_train=False,
-                 load_cache=False):
+                 img_size     :int = 640,
+                 data_dir     :str = None, 
+                 image_set    :str = 'train2017',
+                 trans_config = None,
+                 transform    = None,
+                 is_train     :bool =False,
+                 load_cache   :str  = None):
         """
         COCO dataset initialization. Annotation data are read into memory by COCO API.
         Args:
@@ -59,22 +59,25 @@ class COCODataset(Dataset):
             name (str): COCO data name (e.g. 'train2017' or 'val2017')
             debug (bool): if True, only one data id is selected from the dataset
         """
+        # ----------- Basic parameters -----------
+        self.img_size = img_size
+        self.image_set = image_set
+        self.is_train = is_train
+        self.load_cache = load_cache
+        # ----------- Path parameters -----------
+        self.data_dir = data_dir
         if image_set == 'train2017':
             self.json_file='instances_train2017.json'
         elif image_set == 'val2017':
             self.json_file='instances_val2017.json'
         elif image_set == 'test2017':
             self.json_file='image_info_test-dev2017.json'
-        self.img_size = img_size
-        self.image_set = image_set
-        self.data_dir = data_dir
+        # ----------- Data parameters -----------
         self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
         self.ids = self.coco.getImgIds()
         self.class_ids = sorted(self.coco.getCatIds())
-        self.is_train = is_train
-        self.load_cache = load_cache
-
-        # augmentation
+        self.dataset_size = len(self.ids)
+        # ----------- Transform parameters -----------
         self.transform = transform
         self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
         self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
@@ -85,72 +88,28 @@ class COCODataset(Dataset):
         print('==============================')
         
         # load cache data
-        if load_cache:
+        if load_cache is not None:
             self._load_cache()
 
-
+    # ------------ Basic dataset function ------------
     def __len__(self):
         return len(self.ids)
 
-
     def __getitem__(self, index):
         return self.pull_item(index)
 
-
     def _load_cache(self):
         # load image cache
-        self.cached_images = []
-        self.cached_targets = []
-        dataset_size = len(self.ids)
-
-        print('loading data into memory ...')
-        for i in range(dataset_size):
-            if i % 5000 == 0:
-                print("[{} / {}]".format(i, dataset_size))
-            # load an image
-            image, image_id = self.pull_image(i)
-            orig_h, orig_w, _ = image.shape
-
-            # resize image
-            r = self.img_size / max(orig_h, orig_w)
-            if r != 1: 
-                interp = cv2.INTER_LINEAR
-                new_size = (int(orig_w * r), int(orig_h * r))
-                image = cv2.resize(image, new_size, interpolation=interp)
-            img_h, img_w = image.shape[:2]
-            self.cached_images.append(image)
-
-            # load target cache
-            bboxes, labels = self.pull_anno(i)
-            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
-            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
-            self.cached_targets.append({"boxes": bboxes, "labels": labels})
-        
-
-    def load_image_target(self, index):
-        if self.load_cache:
-            # load data from cache
-            image = self.cached_images[index]
-            target = self.cached_targets[index]
-            height, width, channels = image.shape
-            target["orig_size"] = [height, width]
-        else:
-            # load an image
-            image, _ = self.pull_image(index)
-            height, width, channels = image.shape
-
-            # load a target
-            bboxes, labels = self.pull_anno(index)
-
-            target = {
-                "boxes": bboxes,
-                "labels": labels,
-                "orig_size": [height, width]
-            }
-
-        return image, target
-
-
+        # NOTE: torch.load unpickles the file, so only point load_cache at
+        # cache files you trust (e.g. ones you generated yourself).
+        try:
+            print("Loading cached data ...")
+            self.cached_datas = torch.load(self.load_cache)
+            self.dataset_size = len(self.cached_datas)
+            print("Loading done !")
+        except Exception as err:
+            # Report the offending path *before* clearing it; once load_cache
+            # is None, load_image_target reads images and targets from disk.
+            print("Failed to load cached data from {}: {}".format(self.load_cache, err))
+            self.load_cache = None
+
+    # ------------ Mosaic & Mixup ------------
     def load_mosaic(self, index):
         # load 4x mosaic image
         index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
@@ -173,7 +132,6 @@ class COCODataset(Dataset):
 
         return image, target
 
-
     def load_mixup(self, origin_image, origin_target):
         # YOLOv5 type Mixup
         if self.trans_config['mixup_type'] == 'yolov5_mixup':
@@ -190,6 +148,29 @@ class COCODataset(Dataset):
 
         return image, target
     
+    # ------------ Load data function ------------
+    def load_image_target(self, index):
+        """Return ``(image, target)`` for the sample at ``index``.
+
+        The pair comes from ``self.cached_datas`` when a cache is loaded and
+        the dataset is in training mode; otherwise the image and annotation
+        are read through ``pull_image`` / ``pull_anno``.
+        """
+        # == Cached path ==
+        if self.load_cache and self.is_train:
+            cached_item = self.cached_datas[index]
+            return cached_item["image"], cached_item["target"]
+
+        # == Disk path ==
+        image, _ = self.pull_image(index)
+        img_h, img_w, _ = image.shape
+        bboxes, labels = self.pull_anno(index)
+
+        return image, {
+            "boxes": bboxes,
+            "labels": labels,
+            "orig_size": [img_h, img_w]
+        }
 
     def pull_item(self, index):
         if random.random() < self.mosaic_prob:
@@ -210,7 +191,6 @@ class COCODataset(Dataset):
 
         return image, target, deltas
 
-
     def pull_image(self, index):
         img_id = self.ids[index]
         img_file = os.path.join(self.data_dir, self.image_set,
@@ -226,7 +206,6 @@ class COCODataset(Dataset):
 
         return image, img_id
 
-
     def pull_anno(self, index):
         img_id = self.ids[index]
         im_ann = self.coco.loadImgs(img_id)[0]

+ 208 - 0
dataset/make_dataset.py

@@ -0,0 +1,208 @@
+import os
+import cv2
+import torch
+import random
+import numpy as np
+
+import sys
+sys.path.append("../")
+from utils import distributed_utils
+from dataset.voc import VOCDataset, VOC_CLASSES
+from dataset.coco import COCODataset, coco_class_labels, coco_class_index
+from config import build_trans_config, build_dataset_config
+
+
+def fix_random_seed(args):
+    """Seed the torch, NumPy and stdlib random generators for this process.
+
+    The seed is ``args.seed`` plus the value returned by
+    ``distributed_utils.get_rank()``.
+    """
+    rank_seed = args.seed + distributed_utils.get_rank()
+    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
+        seed_fn(rank_seed)
+
+# ------------------------------ Dataset ------------------------------
+def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
+    """Build the dataset selected by ``args.dataset`` plus its class metadata.
+
+    Args:
+        args: parsed options; reads ``root``, ``dataset``, ``img_size`` and
+            ``load_cache``.
+        data_cfg: dataset config; reads ``data_name``, ``num_classes``,
+            ``class_names`` and ``class_indexs``.
+        trans_config: augmentation config forwarded to the dataset.
+        transform: transform forwarded to the dataset.
+        is_train: when True use the training split, otherwise the VOC test
+            split / COCO val2017 split.
+
+    Returns:
+        Tuple ``(dataset, dataset_info)``; ``dataset_info`` holds
+        ``num_classes``, ``class_names`` and ``class_indexs``.
+
+    Raises:
+        NotImplementedError: if ``args.dataset`` is neither 'voc' nor 'coco'.
+    """
+    # ------------------------- Basic parameters -------------------------
+    data_dir = os.path.join(args.root, data_cfg['data_name'])
+    num_classes = data_cfg['num_classes']
+    class_names = data_cfg['class_names']
+    class_indexs = data_cfg['class_indexs']
+    dataset_info = {
+        'num_classes': num_classes,
+        'class_names': class_names,
+        'class_indexs': class_indexs
+    }
+
+    # ------------------------- Build dataset -------------------------
+    ## VOC dataset
+    if args.dataset == 'voc':
+        dataset = VOCDataset(
+            img_size=args.img_size,
+            data_dir=data_dir,
+            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
+            transform=transform,
+            trans_config=trans_config,
+            load_cache=args.load_cache
+            )
+    ## COCO dataset
+    elif args.dataset == 'coco':
+        dataset = COCODataset(
+            img_size=args.img_size,
+            data_dir=data_dir,
+            image_set='train2017' if is_train else 'val2017',
+            transform=transform,
+            trans_config=trans_config,
+            load_cache=args.load_cache
+            )
+    ## Unknown dataset: fail here instead of hitting an unbound `dataset` below
+    else:
+        raise NotImplementedError("Unknown dataset: {}".format(args.dataset))
+
+    return dataset, dataset_info
+
+def visualize(image, target, dataset_name="voc"):
+    """Draw the boxes and class names of ``target`` on ``image`` and show it.
+
+    Boxes whose width or height is not greater than 1 are skipped. The
+    window stays open until a key is pressed.
+
+    Args:
+        image: image array; it is cast to uint8 and copied before drawing.
+        target: dict with "boxes" (x1, y1, x2, y2 per entry) and "labels".
+        dataset_name: "voc" or "coco"; selects the class-name table.
+
+    Raises:
+        NotImplementedError: for any other ``dataset_name``.
+    """
+    if dataset_name == "voc":
+        class_labels, class_indexs, num_classes = VOC_CLASSES, None, 20
+    elif dataset_name == "coco":
+        class_labels, class_indexs, num_classes = coco_class_labels, coco_class_index, 80
+    else:
+        raise NotImplementedError
+
+    # one random color per class, used for the class-name text
+    class_colors = [tuple(np.random.randint(255) for _ in range(3))
+                    for _ in range(num_classes)]
+
+    # work on a uint8 copy so the caller's array is left untouched
+    image = image.astype(np.uint8).copy()
+
+    for box, label in zip(target["boxes"], target["labels"]):
+        x1, y1, x2, y2 = box
+        if not (x2 - x1 > 1 and y2 - y1 > 1):
+            continue
+
+        cls_id = int(label)
+        color = class_colors[cls_id]
+        # class name
+        if dataset_name == 'coco':
+            assert class_indexs is not None
+            class_name = class_labels[class_indexs[cls_id]]
+        else:
+            class_name = class_labels[cls_id]
+
+        top_left = (int(x1), int(y1))
+        bottom_right = (int(x2), int(y2))
+        image = cv2.rectangle(image, top_left, bottom_right, (0,0,255), 2)
+        # put the text just above the bbox
+        cv2.putText(image, class_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
+
+    cv2.imshow('gt', image)
+    cv2.waitKey(0)
+
+
+if __name__ == "__main__":
+    # Script entry point: load every sample of the selected split, resize it so
+    # that its longer side equals --img_size, rescale the boxes to match, and
+    # save the list of {"image", "target"} dicts with torch.save.
+    import argparse
+    from build import build_transform
+
+    parser = argparse.ArgumentParser(description='Make cached dataset')
+
+    # Seed
+    parser.add_argument('--seed', default=42, type=int)
+    # Dataset
+    parser.add_argument('--root', default='/Users/yjh0410/Desktop/python_work/dataset/',
+                        help='data root')
+    parser.add_argument('--dataset', type=str, default="voc",
+                        help='dataset name: voc or coco.')
+    # The dataset classes treat load_cache as a path (None disables it), so
+    # it is a string option here as well rather than a boolean flag.
+    parser.add_argument('--load_cache', type=str, default=None,
+                        help='Path to the cached data.')
+    parser.add_argument('--vis_tgt', action="store_true", default=False,
+                        help='visualize the targets instead of caching them.')
+    parser.add_argument('--is_train', action="store_true", default=False,
+                        help='use the training split.')
+    # Image size
+    parser.add_argument('-size', '--img_size', default=640, type=int,
+                        help='input image size.')
+    # Augmentation
+    parser.add_argument('--aug_type', type=str, default="yolov5_nano",
+                        help='augmentation type.')
+    parser.add_argument('--mosaic', default=None, type=float,
+                        help='mosaic augmentation.')
+    parser.add_argument('--mixup', default=None, type=float,
+                        help='mixup augmentation.')
+    # DDP train
+    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
+                        help='distributed training')
+    parser.add_argument('--dist_url', default='env://', 
+                        help='url used to set up distributed training')
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--sybn', action='store_true', default=False, 
+                        help='use sybn.')
+    # Output
+    parser.add_argument('--output_dir', type=str, default='cache_data/',
+                        help='directory where the cached data is saved.')
+
+    args = parser.parse_args()
+
+
+    assert args.aug_type in ["yolov5_pico", "yolov5_nano", "yolov5_small", "yolov5_medium", "yolov5_large", "yolov5_huge",
+                             "yolox_pico",  "yolox_nano",  "yolox_small",  "yolox_medium",  "yolox_large",  "yolox_huge"]
+
+
+    # ------------- Build transform config -------------
+    dataset_cfg  = build_dataset_config(args)
+    trans_config = build_trans_config(args.aug_type)
+
+    # ------------- Build transform -------------
+    transform, trans_config = build_transform(args, trans_config, max_stride=32, is_train=args.is_train)
+
+    # ------------- Build dataset -------------
+    dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config, transform, is_train=args.is_train)
+    num_items = len(dataset)
+    print('Data length: ', num_items)
+
+    # ---------------------------- Fix random seed ----------------------------
+    fix_random_seed(args)
+
+    # ---------------------------- Main process ----------------------------
+    # Collect one {"image", "target"} dict per sample of the selected split
+    data_items = []
+    for idx in range(num_items):
+        if idx % 2000 == 0:
+            print("Caching images and targets : {} / {} ...".format(idx, num_items))
+
+        # load a data
+        image, target = dataset.load_image_target(idx)
+        orig_h, orig_w, _ = image.shape
+
+        # resize image so that its longer side equals img_size
+        r = args.img_size / max(orig_h, orig_w)
+        if r != 1: 
+            interp = cv2.INTER_LINEAR
+            new_size = (int(orig_w * r), int(orig_h * r))
+            image = cv2.resize(image, new_size, interpolation=interp)
+        img_h, img_w = image.shape[:2]
+
+        # rescale bbox to the resized image (on a copy of the box array)
+        boxes = target["boxes"].copy()
+        boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
+        boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
+        target["boxes"] = boxes
+        # NOTE(review): target["orig_size"] is left at the pre-resize size --
+        # confirm that consumers of the cache expect that.
+
+        # visualize data
+        if args.vis_tgt:
+            print(image.shape)
+            visualize(image, target, args.dataset)
+            continue
+
+        dict_item = {}
+        dict_item["image"] = image
+        dict_item["target"] = target
+
+        data_items.append(dict_item)
+
+    if args.vis_tgt:
+        # Visualization mode collects nothing; do not write an empty cache,
+        # which could overwrite a previously generated one.
+        print('Visualization only: no cache file is written.')
+    else:
+        output_dir = os.path.join(args.output_dir, args.dataset)
+        os.makedirs(output_dir, exist_ok=True)
+
+        print('Cached data size: ', len(data_items))
+        split_name = "train" if args.is_train else "valid"
+        save_file = os.path.join(output_dir, "{}_{}.pth".format(args.dataset, split_name))
+        torch.save(data_items, save_file)

+ 52 - 80
dataset/ourdataset.py

@@ -1,9 +1,10 @@
 import os
 import cv2
+import time
 import random
 import numpy as np
-import time
 
+import torch
 from torch.utils.data import Dataset
 
 try:
@@ -26,32 +27,27 @@ class OurDataset(Dataset):
     Our dataset class.
     """
     def __init__(self, 
-                 img_size=640,
-                 data_dir=None, 
-                 image_set='train',
-                 transform=None,
-                 trans_config=None,
-                 is_train=False,
-                 load_cache=False):
-        """
-        COCO dataset initialization. Annotation data are read into memory by COCO API.
-        Args:
-            data_dir (str): dataset root directory
-            json_file (str): COCO json file name
-            name (str): COCO data name (e.g. 'train2017' or 'val2017')
-            debug (bool): if True, only one data id is selected from the dataset
-        """
+                 img_size     :int  = 640,
+                 data_dir     :str  = None, 
+                 image_set    :str  = 'train',
+                 transform          = None,
+                 trans_config       = None,
+                 is_train     :bool = False,
+                 load_cache   :str  = None):
+        # ----------- Basic parameters -----------
         self.img_size = img_size
         self.image_set = image_set
-        self.json_file = '{}.json'.format(image_set)
+        self.is_train = is_train
+        self.load_cache = load_cache
+        # ----------- Path parameters -----------
         self.data_dir = data_dir
+        self.json_file = '{}.json'.format(image_set)
+        # ----------- Data parameters -----------
         self.coco = COCO(os.path.join(self.data_dir, image_set, 'annotations', self.json_file))
         self.ids = self.coco.getImgIds()
         self.class_ids = sorted(self.coco.getCatIds())
-        self.is_train = is_train
-        self.load_cache = load_cache
-
-        # augmentation
+        self.dataset_size = len(self.ids)
+        # ----------- Transform parameters -----------
         self.transform = transform
         self.mosaic_prob = 0
         self.mixup_prob = 0
@@ -68,72 +64,28 @@ class OurDataset(Dataset):
         print('==============================')
 
         # load cache data
-        if load_cache:
+        if load_cache is not None and is_train:
             self._load_cache()
 
-
+    # ------------ Basic dataset function ------------
     def __len__(self):
         return len(self.ids)
 
-
     def __getitem__(self, index):
         return self.pull_item(index)
 
-
     def _load_cache(self):
         # load image cache
-        self.cached_images = []
-        self.cached_targets = []
-        dataset_size = len(self.ids)
-
-        print('loading data into memory ...')
-        for i in range(dataset_size):
-            if i % 5000 == 0:
-                print("[{} / {}]".format(i, dataset_size))
-            # load an image
-            image, image_id = self.pull_image(i)
-            orig_h, orig_w, _ = image.shape
-
-            # resize image
-            r = self.img_size / max(orig_h, orig_w)
-            if r != 1: 
-                interp = cv2.INTER_LINEAR
-                new_size = (int(orig_w * r), int(orig_h * r))
-                image = cv2.resize(image, new_size, interpolation=interp)
-            img_h, img_w = image.shape[:2]
-            self.cached_images.append(image)
-
-            # load target cache
-            bboxes, labels = self.pull_anno(i)
-            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
-            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
-            self.cached_targets.append({"boxes": bboxes, "labels": labels})
-        
-
-    def load_image_target(self, index):
-        if self.load_cache:
-            # load data from cache
-            image = self.cached_images[index]
-            target = self.cached_targets[index]
-            height, width, channels = image.shape
-            target["orig_size"] = [height, width]
-        else:
-            # load an image
-            image, _ = self.pull_image(index)
-            height, width, channels = image.shape
-
-            # load a target
-            bboxes, labels = self.pull_anno(index)
-
-            target = {
-                "boxes": bboxes,
-                "labels": labels,
-                "orig_size": [height, width]
-            }
-
-        return image, target
-
-
+        # NOTE: torch.load unpickles the file, so only point load_cache at
+        # cache files you trust (e.g. ones you generated yourself).
+        try:
+            print("Loading cached data ...")
+            self.cached_datas = torch.load(self.load_cache)
+            self.dataset_size = len(self.cached_datas)
+            print("Loading done !")
+        except Exception as err:
+            # Report the offending path *before* clearing it; once load_cache
+            # is None, load_image_target reads images and targets from disk.
+            print("Failed to load cached data from {}: {}".format(self.load_cache, err))
+            self.load_cache = None
+
+    # ------------ Mosaic & Mixup ------------
     def load_mosaic(self, index):
         # load 4x mosaic image
         index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
@@ -156,7 +108,6 @@ class OurDataset(Dataset):
 
         return image, target
 
-
     def load_mixup(self, origin_image, origin_target):
         # YOLOv5 type Mixup
         if self.trans_config['mixup_type'] == 'yolov5_mixup':
@@ -173,6 +124,29 @@ class OurDataset(Dataset):
 
         return image, target
     
+    # ------------ Load data function ------------
+    def load_image_target(self, index):
+        """Return ``(image, target)`` for the sample at ``index``.
+
+        The pair comes from ``self.cached_datas`` when a cache is loaded and
+        the dataset is in training mode; otherwise the image and annotation
+        are read through ``pull_image`` / ``pull_anno``.
+        """
+        # == Cached path ==
+        if self.load_cache and self.is_train:
+            cached_item = self.cached_datas[index]
+            return cached_item["image"], cached_item["target"]
+
+        # == Disk path ==
+        image, _ = self.pull_image(index)
+        img_h, img_w, _ = image.shape
+        bboxes, labels = self.pull_anno(index)
+
+        return image, {
+            "boxes": bboxes,
+            "labels": labels,
+            "orig_size": [img_h, img_w]
+        }
 
     def pull_item(self, index):
         if random.random() < self.mosaic_prob:
@@ -193,7 +167,6 @@ class OurDataset(Dataset):
 
         return image, target, deltas
 
-
     def pull_image(self, index):
         id_ = self.ids[index]
         im_ann = self.coco.loadImgs(id_)[0] 
@@ -203,7 +176,6 @@ class OurDataset(Dataset):
 
         return image, id_
 
-
     def pull_anno(self, index):
         img_id = self.ids[index]
         im_ann = self.coco.loadImgs(img_id)[0]

+ 74 - 138
dataset/voc.py

@@ -1,30 +1,21 @@
-"""VOC Dataset Classes
-
-Original author: Francisco Massa
-https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
-
-Updated by: Ellis Brown, Max deGroot
-"""
-import os.path as osp
-import random
-import torch.utils.data as data
+import os
 import cv2
+import torch
+import random
 import numpy as np
+import os.path as osp
 import xml.etree.ElementTree as ET
 
+import torch
+import torch.utils.data as data
 try:
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 
-
-VOC_CLASSES = (  # always index 0
-    'aeroplane', 'bicycle', 'bird', 'boat',
-    'bottle', 'bus', 'car', 'cat', 'chair',
-    'cow', 'diningtable', 'dog', 'horse',
-    'motorbike', 'person', 'pottedplant',
-    'sheep', 'sofa', 'train', 'tvmonitor')
+# VOC class names
+VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
 
 
 class VOCAnnotationTransform(object):
@@ -74,47 +65,34 @@ class VOCAnnotationTransform(object):
         return res  # [[x1, y1, x2, y2, label_ind], ... ]
 
 
-class VOCDetection(data.Dataset):
-    """VOC Detection Dataset Object
-
-    input is image, target is annotation
-
-    Arguments:
-        root (string): filepath to VOCdevkit folder.
-        image_set (string): imageset to use (eg. 'train', 'val', 'test')
-        transform (callable, optional): transformation to perform on the
-            input image
-        target_transform (callable, optional): transformation to perform on the
-            target `annotation`
-            (eg: take in caption string, return tensor of word indices)
-        dataset_name (string, optional): which dataset to load
-            (default: 'VOC2007')
-    """
-
+class VOCDataset(data.Dataset):
     def __init__(self, 
-                 img_size=640,
-                 data_dir=None,
-                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
-                 trans_config=None,
-                 transform=None,
-                 is_train=False,
-                 load_cache=False
+                 img_size     :int = 640,
+                 data_dir     :str = None,
+                 image_sets   = [('2007', 'trainval'), ('2012', 'trainval')],
+                 trans_config = None,
+                 transform    = None,
+                 is_train     :bool = False,
+                 load_cache   :str  = None,
                  ):
-        self.root = data_dir
+        # ----------- Basic parameters -----------
         self.img_size = img_size
         self.image_set = image_sets
+        self.is_train = is_train
+        self.load_cache = load_cache
         self.target_transform = VOCAnnotationTransform()
+        # ----------- Path parameters -----------
+        self.root = data_dir
         self._annopath = osp.join('%s', 'Annotations', '%s.xml')
         self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
+        # ----------- Data parameters -----------
         self.ids = list()
-        self.is_train = is_train
-        self.load_cache = load_cache
         for (year, name) in image_sets:
             rootpath = osp.join(self.root, 'VOC' + year)
             for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                 self.ids.append((rootpath, line.strip()))
-
-        # augmentation
+        self.dataset_size = len(self.ids)
+        # ----------- Transform parameters -----------
         self.transform = transform
         self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
         self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
@@ -125,81 +103,29 @@ class VOCDetection(data.Dataset):
         print('==============================')
 
         # load cache data
-        if load_cache:
+        if load_cache is not None:
             self._load_cache()
 
-
+    # ------------ Basic dataset function ------------
     def __getitem__(self, index):
         image, target, deltas = self.pull_item(index)
         return image, target, deltas
 
-
     def __len__(self):
-        return len(self.ids)
-
+        return self.dataset_size
 
     def _load_cache(self):
         # load image cache
-        self.cached_images = []
-        self.cached_targets = []
-        dataset_size = len(self.ids)
-
-        print('loading data into memory ...')
-        for i in range(dataset_size):
-            if i % 5000 == 0:
-                print("[{} / {}]".format(i, dataset_size))
-            # load an image
-            image, image_id = self.pull_image(i)
-            orig_h, orig_w, _ = image.shape
-
-            # resize image
-            r = self.img_size / max(orig_h, orig_w)
-            if r != 1: 
-                interp = cv2.INTER_LINEAR
-                new_size = (int(orig_w * r), int(orig_h * r))
-                image = cv2.resize(image, new_size, interpolation=interp)
-            img_h, img_w = image.shape[:2]
-            self.cached_images.append(image)
-
-            # load target cache
-            anno = ET.parse(self._annopath % image_id).getroot()
-            anno = self.target_transform(anno)
-            anno = np.array(anno).reshape(-1, 5)
-            boxes = anno[:, :4]
-            labels = anno[:, 4]
-            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
-            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
-            self.cached_targets.append({"boxes": boxes, "labels": labels})
-        
-
-    def load_image_target(self, index):
-        if self.load_cache:
-            image = self.cached_images[index]
-            target = self.cached_targets[index]
-            height, width, channels = image.shape
-            target["orig_size"] = [height, width]
-        else:
-            # load an image
-            img_id = self.ids[index]
-            image = cv2.imread(self._imgpath % img_id)
-            height, width, channels = image.shape
-
-            # laod an annotation
-            anno = ET.parse(self._annopath % img_id).getroot()
-            if self.target_transform is not None:
-                anno = self.target_transform(anno)
-
-            # guard against no boxes via resizing
-            anno = np.array(anno).reshape(-1, 5)
-            target = {
-                "boxes": anno[:, :4],
-                "labels": anno[:, 4],
-                "orig_size": [height, width]
-            }
-        
-        return image, target
-
-
+        # NOTE: torch.load unpickles the file, so only point load_cache at
+        # cache files you trust (e.g. ones you generated yourself).
+        try:
+            print("Loading cached data ...")
+            self.cached_datas = torch.load(self.load_cache)
+            self.dataset_size = len(self.cached_datas)
+            print("Loading done !")
+        except Exception as err:
+            # Report the offending path *before* clearing it; once load_cache
+            # is None, load_image_target reads images and targets from disk.
+            print("Failed to load cached data from {}: {}".format(self.load_cache, err))
+            self.load_cache = None
+
+    # ------------ Mosaic & Mixup ------------
     def load_mosaic(self, index):
         # load 4x mosaic image
         index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
@@ -222,7 +148,6 @@ class VOCDetection(data.Dataset):
 
         return image, target
 
-
     def load_mixup(self, origin_image, origin_target):
         # YOLOv5 type Mixup
         if self.trans_config['mixup_type'] == 'yolov5_mixup':
@@ -239,6 +164,32 @@ class VOCDetection(data.Dataset):
 
         return image, target
     
+    # ------------ Load data function ------------
+    def load_image_target(self, index):
+        """Return ``(image, target)`` for the sample at ``index``.
+
+        The pair comes from ``self.cached_datas`` when a cache is loaded and
+        the dataset is in training mode; otherwise the image and annotation
+        are read through ``pull_image`` / ``pull_anno``.
+        """
+        # == Cached path ==
+        if self.load_cache and self.is_train:
+            cached_item = self.cached_datas[index]
+            return cached_item["image"], cached_item["target"]
+
+        # == Disk path ==
+        image, _ = self.pull_image(index)
+        img_h, img_w, _ = image.shape
+
+        # reshape(-1, 5) keeps a well-formed (0, 5) array when there are no boxes
+        raw_anno, _ = self.pull_anno(index)
+        anno = np.array(raw_anno).reshape(-1, 5)
+
+        return image, {
+            "boxes": anno[:, :4],
+            "labels": anno[:, 4],
+            "orig_size": [img_h, img_w]
+        }
 
     def pull_item(self, index):
         if random.random() < self.mosaic_prob:
@@ -259,34 +210,18 @@ class VOCDetection(data.Dataset):
 
         return image, target, deltas
 
-
     def pull_image(self, index):
-        '''Returns the original image object at index in PIL form
-        Note: not using self.__getitem__(), as any transformations passed in
-        could mess up this functionality.
-        Argument:
-            index (int): index of img to show
-        Return:
-            PIL img
-        '''
         img_id = self.ids[index]
-        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR), img_id
+        image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
 
+        return image, img_id
 
     def pull_anno(self, index):
-        '''Returns the original annotation of image at index
-        Note: not using self.__getitem__(), as any transformations passed in
-        could mess up this functionality.
-        Argument:
-            index (int): index of img to get annotation of
-        Return:
-            list:  [img_id, [(label, bbox coords),...]]
-                eg: ('001718', [('dog', (96, 13, 438, 332))])
-        '''
         img_id = self.ids[index]
         anno = ET.parse(self._annopath % img_id).getroot()
-        gt = self.target_transform(anno, 1, 1)
-        return img_id[1], gt
+        anno = self.target_transform(anno)
+
+        return anno, img_id
 
 
 if __name__ == "__main__":
@@ -306,8 +241,8 @@ if __name__ == "__main__":
                         help='mixup augmentation.')
     parser.add_argument('--is_train', action="store_true", default=False,
                         help='mixup augmentation.')
-    parser.add_argument('--load_cache', action="store_true", default=False,
-                        help='load cached data.')
+    parser.add_argument('--load_cache', type=str, default=None,
+                        help='Path to the cached data.')
     
     args = parser.parse_args()
 
@@ -324,17 +259,18 @@ if __name__ == "__main__":
         'hsv_v': 0.4,
         'use_ablu': True,
         # Mosaic & Mixup
-        'mosaic_prob': 1.0,
-        'mixup_prob': 1.0,
+        'mosaic_prob': args.mosaic,
+        'mixup_prob': args.mixup,
         'mosaic_type': 'yolov5_mosaic',
         'mixup_type': 'yolov5_mixup',
         'mixup_scale': [0.5, 1.5]
     }
     transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
 
-    dataset = VOCDetection(
+    dataset = VOCDataset(
         img_size=args.img_size,
         data_dir=args.root,
+        image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if args.is_train else [('2007', 'test')],
         trans_config=trans_config,
         transform=transform,
         is_train=args.is_train,