|
|
@@ -140,21 +140,41 @@ class VOCDetection(data.Dataset):
|
|
|
|
|
|
def _load_cache(self):
|
|
|
# load image cache
|
|
|
- self.image_list = None # TODO: H5PY file
|
|
|
+ self.cached_images = []
|
|
|
+ self.cached_targets = []
|
|
|
+ dataset_size = len(self.ids)
|
|
|
|
|
|
- # load target cache
|
|
|
- self.target_list = []
|
|
|
- for img_id in self.ids:
|
|
|
- anno = ET.parse(self._annopath % img_id).getroot()
|
|
|
+ for i in range(dataset_size):
|
|
|
+ if (i+1) % 5000 == 0:
|
|
|
+ print("[{} / {}]".format(i, dataset_size))
|
|
|
+ # load an image
|
|
|
+ image, image_id = self.pull_image(i)
|
|
|
+ orig_h, orig_w, _ = image.shape
|
|
|
+
|
|
|
+ # resize image
|
|
|
+ r = args.img_size / max(orig_h, orig_w)
|
|
|
+ if r != 1:
|
|
|
+ interp = cv2.INTER_LINEAR
|
|
|
+ new_size = (int(orig_w * r), int(orig_h * r))
|
|
|
+ image = cv2.resize(image, new_size, interpolation=interp)
|
|
|
+ img_h, img_w = image.shape[:2]
|
|
|
+ self.cached_images.append(image)
|
|
|
+
|
|
|
+ # load target cache
|
|
|
+ anno = ET.parse(self._annopath % image_id).getroot()
|
|
|
anno = self.target_transform(anno)
|
|
|
anno = np.array(anno).reshape(-1, 5)
|
|
|
- self.target_list.append({"boxes": anno[:, :4], "labels": anno[:, 4]})
|
|
|
+ boxes = anno[:, :4]
|
|
|
+ labels = anno[:, 4]
|
|
|
+ boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
|
|
|
+ boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
|
|
|
+ self.cached_targets.append({"boxes": boxes, "labels": labels})
|
|
|
|
|
|
|
|
|
def load_image_target(self, index):
|
|
|
if self.load_cache:
|
|
|
- image = self.image_list[index]
|
|
|
- target = self.target_list[index]
|
|
|
+ image = self.cached_images[index]
|
|
|
+ target = self.cached_targets[index]
|
|
|
height, width, channels = image.shape
|
|
|
target["orig_size"] = [height, width]
|
|
|
else:
|
|
|
@@ -315,7 +335,8 @@ if __name__ == "__main__":
|
|
|
data_dir=args.root,
|
|
|
trans_config=trans_config,
|
|
|
transform=transform,
|
|
|
- is_train=args.is_train
|
|
|
+ is_train=args.is_train,
|
|
|
+ load_cache=args.load_cache
|
|
|
)
|
|
|
|
|
|
np.random.seed(0)
|