Browse Source

debug cache

yjh0410 2 years ago
parent
commit
d6cfe501c7
1 changed files with 30 additions and 9 deletions
  1. 30 9
      dataset/voc.py

+ 30 - 9
dataset/voc.py

@@ -140,21 +140,41 @@ class VOCDetection(data.Dataset):
 
     def _load_cache(self):
         # load image cache
-        self.image_list = None  # TODO: H5PY file
+        self.cached_images = []
+        self.cached_targets = []
+        dataset_size = len(self.ids)
 
-        # load target cache
-        self.target_list = []
-        for img_id in self.ids:
-            anno = ET.parse(self._annopath % img_id).getroot()
+        for i in range(dataset_size):
+            if (i+1) % 5000 == 0:
+                print("[{} / {}]".format(i, dataset_size))
+            # load an image
+            image, image_id = self.pull_image(i)
+            orig_h, orig_w, _ = image.shape
+
+            # resize image
+            r = args.img_size / max(orig_h, orig_w)
+            if r != 1: 
+                interp = cv2.INTER_LINEAR
+                new_size = (int(orig_w * r), int(orig_h * r))
+                image = cv2.resize(image, new_size, interpolation=interp)
+            img_h, img_w = image.shape[:2]
+            self.cached_images.append(image)
+
+            # load target cache
+            anno = ET.parse(self._annopath % image_id).getroot()
             anno = self.target_transform(anno)
             anno = np.array(anno).reshape(-1, 5)
-            self.target_list.append({"boxes": anno[:, :4], "labels": anno[:, 4]})
+            boxes = anno[:, :4]
+            labels = anno[:, 4]
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
+            self.cached_targets.append({"boxes": boxes, "labels": labels})
         
 
     def load_image_target(self, index):
         if self.load_cache:
-            image = self.image_list[index]
-            target = self.target_list[index]
+            image = self.cached_images[index]
+            target = self.cached_targets[index]
             height, width, channels = image.shape
             target["orig_size"] = [height, width]
         else:
@@ -315,7 +335,8 @@ if __name__ == "__main__":
         data_dir=args.root,
         trans_config=trans_config,
         transform=transform,
-        is_train=args.is_train
+        is_train=args.is_train,
+        load_cache=args.load_cache
         )
     
     np.random.seed(0)