| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- import os
- import numpy as np
- import torch.utils.data as data
- import torchvision.transforms as T
- from torchvision.datasets import CIFAR10
- class CifarDataset(data.Dataset):
- def __init__(self, is_train=False, transform=None):
- super().__init__()
- # ----------------- basic parameters -----------------
- self.is_train = is_train
- self.pixel_mean = [0.0]
- self.pixel_std = [1.0]
- self.image_set = 'train' if is_train else 'val'
- # ----------------- dataset & transforms -----------------
- self.transform = self.build_transform()
- path = os.path.dirname(os.path.abspath(__file__))
- if is_train:
- self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=True, download=True, transform=self.transform)
- else:
- self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=False, download=True, transform=self.transform)
- def __len__(self):
- return len(self.dataset)
-
- def __getitem__(self, index):
- image, target = self.dataset[index]
-
- return image, target
-
- def pull_image(self, index):
- # laod data
- image, target = self.dataset[index]
- # ------- Denormalize image -------
- ## [C, H, W] -> [H, W, C], torch.Tensor -> numpy.adnarry
- image = image.permute(1, 2, 0).numpy()
- ## Denomalize: I = I_n * std + mean, I = I * 255
- image = (image * self.pixel_std + self.pixel_mean) * 255.
- image = image.astype(np.uint8)
- image = image.copy()
- return image, target
- def build_transform(self):
- if self.is_train:
- transforms = T.Compose([T.ToTensor(), T.RandomCrop(size=32, padding=8)])
- else:
- transforms = T.Compose([T.ToTensor()])
- return transforms
- if __name__ == "__main__":
- import cv2
- import argparse
-
- parser = argparse.ArgumentParser(description='Cifar-Dataset')
- # opt
- parser.add_argument('--is_train', action="store_true", default=False,
- help='train or not.')
-
- args = parser.parse_args()
- # dataset
- dataset = CifarDataset(is_train=args.is_train)
- print('Dataset size: ', len(dataset))
- for i in range(1000):
- image, target = dataset.pull_image(i)
- # to BGR
- image = image[..., (2, 1, 0)]
- cv2.imshow('image', image)
- cv2.waitKey(0)
|