yjh0410 2 years ago
parent
commit
7d261c14bd
3 changed files with 114 additions and 24 deletions
  1. 56 12
      dataset/coco.py
  2. 57 11
      dataset/ourdataset.py
  3. 1 1
      dataset/voc.py

+ 56 - 12
dataset/coco.py

@@ -49,7 +49,8 @@ class COCODataset(Dataset):
                  image_set='train2017',
                  trans_config=None,
                  transform=None,
-                 is_train=False):
+                 is_train=False,
+                 load_cache=False):
         """
         COCO dataset initialization. Annotation data are read into memory by COCO API.
         Args:
@@ -82,6 +83,10 @@ class COCODataset(Dataset):
         print('use Mixup Augmentation: {}'.format(self.mixup_prob))
         print('==============================')
         
+        # load cache data
+        if load_cache:
+            self._load_cache()
+
 
     def __len__(self):
         return len(self.ids)
@@ -91,19 +96,55 @@ class COCODataset(Dataset):
         return self.pull_item(index)
 
 
+    def _load_cache(self):
+        # load image cache
+        self.cached_images = []
+        self.cached_targets = []
+        dataset_size = len(self.ids)
+
+        for i in range(dataset_size):
+            if i % 5000 == 0:
+                print("[{} / {}]".format(i, dataset_size))
+            # load an image
+            image, image_id = self.pull_image(i)
+            orig_h, orig_w, _ = image.shape
+
+            # resize image
+            r = args.img_size / max(orig_h, orig_w)
+            if r != 1: 
+                interp = cv2.INTER_LINEAR
+                new_size = (int(orig_w * r), int(orig_h * r))
+                image = cv2.resize(image, new_size, interpolation=interp)
+            img_h, img_w = image.shape[:2]
+            self.cached_images.append(image)
+
+            # load target cache
+            bboxes, labels = self.pull_anno(i)
+            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
+            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
+            self.cached_targets.append({"boxes": bboxes, "labels": labels})
+        
+
     def load_image_target(self, index):
-        # load an image
-        image, _ = self.pull_image(index)
-        height, width, channels = image.shape
+        if self.load_cache:
+            # load data from cache
+            image = self.cached_images[index]
+            target = self.cached_targets[index]
+            height, width, channels = image.shape
+            target["orig_size"] = [height, width]
+        else:
+            # load an image
+            image, _ = self.pull_image(index)
+            height, width, channels = image.shape
 
-        # load a target
-        bboxes, labels = self.pull_anno(index)
+            # load a target
+            bboxes, labels = self.pull_anno(index)
 
-        target = {
-            "boxes": bboxes,
-            "labels": labels,
-            "orig_size": [height, width]
-        }
+            target = {
+                "boxes": bboxes,
+                "labels": labels,
+                "orig_size": [height, width]
+            }
 
         return image, target
 
@@ -236,6 +277,8 @@ if __name__ == "__main__":
                         help='mixup augmentation.')
     parser.add_argument('--is_train', action="store_true", default=False,
                         help='mixup augmentation.')
+    parser.add_argument('--load_cache', action="store_true", default=False,
+                        help='load cached data.')
     
     args = parser.parse_args()
 
@@ -266,7 +309,8 @@ if __name__ == "__main__":
         image_set='val2017',
         trans_config=trans_config,
         transform=transform,
-        is_train=args.is_train
+        is_train=args.is_train,
+        load_cache=args.load_cache
         )
     
     np.random.seed(0)

+ 57 - 11
dataset/ourdataset.py

@@ -31,7 +31,8 @@ class OurDataset(Dataset):
                  image_set='train',
                  transform=None,
                  trans_config=None,
-                 is_train=False):
+                 is_train=False,
+                 load_cache=False):
         """
         COCO dataset initialization. Annotation data are read into memory by COCO API.
         Args:
@@ -48,6 +49,7 @@ class OurDataset(Dataset):
         self.ids = self.coco.getImgIds()
         self.class_ids = sorted(self.coco.getCatIds())
         self.is_train = is_train
+        self.load_cache = load_cache
 
         # augmentation
         self.transform = transform
@@ -65,6 +67,10 @@ class OurDataset(Dataset):
         print('use Mixup Augmentation: {}'.format(self.mixup_prob))
         print('==============================')
 
+        # load cache data
+        if load_cache:
+            self._load_cache()
+
 
     def __len__(self):
         return len(self.ids)
@@ -74,19 +80,55 @@ class OurDataset(Dataset):
         return self.pull_item(index)
 
 
+    def _load_cache(self):
+        # load image cache
+        self.cached_images = []
+        self.cached_targets = []
+        dataset_size = len(self.ids)
+
+        for i in range(dataset_size):
+            if i % 5000 == 0:
+                print("[{} / {}]".format(i, dataset_size))
+            # load an image
+            image, image_id = self.pull_image(i)
+            orig_h, orig_w, _ = image.shape
+
+            # resize image
+            r = args.img_size / max(orig_h, orig_w)
+            if r != 1: 
+                interp = cv2.INTER_LINEAR
+                new_size = (int(orig_w * r), int(orig_h * r))
+                image = cv2.resize(image, new_size, interpolation=interp)
+            img_h, img_w = image.shape[:2]
+            self.cached_images.append(image)
+
+            # load target cache
+            bboxes, labels = self.pull_anno(i)
+            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
+            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
+            self.cached_targets.append({"boxes": bboxes, "labels": labels})
+        
+
     def load_image_target(self, index):
-        # load an image
-        image, _ = self.pull_image(index)
-        height, width, channels = image.shape
+        if self.load_cache:
+            # load data from cache
+            image = self.cached_images[index]
+            target = self.cached_targets[index]
+            height, width, channels = image.shape
+            target["orig_size"] = [height, width]
+        else:
+            # load an image
+            image, _ = self.pull_image(index)
+            height, width, channels = image.shape
 
-        # load a target
-        bboxes, labels = self.pull_anno(index)
+            # load a target
+            bboxes, labels = self.pull_anno(index)
 
-        target = {
-            "boxes": bboxes,
-            "labels": labels,
-            "orig_size": [height, width]
-        }
+            target = {
+                "boxes": bboxes,
+                "labels": labels,
+                "orig_size": [height, width]
+            }
 
         return image, target
 
@@ -217,6 +259,8 @@ if __name__ == "__main__":
                         help='mixup augmentation.')
     parser.add_argument('--is_train', action="store_true", default=False,
                         help='mixup augmentation.')
+    parser.add_argument('--load_cache', action="store_true", default=False,
+                        help='load cached data.')
     
     args = parser.parse_args()
 
@@ -247,6 +291,8 @@ if __name__ == "__main__":
         image_set=args.split,
         transform=transform,
         trans_config=trans_config,
+        is_train=args.is_train,
+        load_cache=args.load_cache
         )
     
     np.random.seed(0)

+ 1 - 1
dataset/voc.py

@@ -145,7 +145,7 @@ class VOCDetection(data.Dataset):
         dataset_size = len(self.ids)
 
         for i in range(dataset_size):
-            if (i+1) % 5000 == 0:
+            if i % 5000 == 0:
                 print("[{} / {}]".format(i, dataset_size))
             # load an image
             image, image_id = self.pull_image(i)