cifar.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import numpy as np
  3. import torch.utils.data as data
  4. import torchvision.transforms as T
  5. from torchvision.datasets import CIFAR10
  6. class CifarDataset(data.Dataset):
  7. def __init__(self, is_train=False, transform=None):
  8. super().__init__()
  9. # ----------------- basic parameters -----------------
  10. self.is_train = is_train
  11. self.pixel_mean = [0.0]
  12. self.pixel_std = [1.0]
  13. self.image_set = 'train' if is_train else 'val'
  14. # ----------------- dataset & transforms -----------------
  15. self.transform = self.build_transform()
  16. path = os.path.dirname(os.path.abspath(__file__))
  17. if is_train:
  18. self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=True, download=True, transform=self.transform)
  19. else:
  20. self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=False, download=True, transform=self.transform)
  21. def __len__(self):
  22. return len(self.dataset)
  23. def __getitem__(self, index):
  24. image, target = self.dataset[index]
  25. return image, target
  26. def pull_image(self, index):
  27. # laod data
  28. image, target = self.dataset[index]
  29. # ------- Denormalize image -------
  30. ## [C, H, W] -> [H, W, C], torch.Tensor -> numpy.adnarry
  31. image = image.permute(1, 2, 0).numpy()
  32. ## Denomalize: I = I_n * std + mean, I = I * 255
  33. image = (image * self.pixel_std + self.pixel_mean) * 255.
  34. image = image.astype(np.uint8)
  35. image = image.copy()
  36. return image, target
  37. def build_transform(self):
  38. if self.is_train:
  39. transforms = T.Compose([T.ToTensor(), T.RandomCrop(size=32, padding=8)])
  40. else:
  41. transforms = T.Compose([T.ToTensor()])
  42. return transforms
  43. if __name__ == "__main__":
  44. import cv2
  45. import argparse
  46. parser = argparse.ArgumentParser(description='Cifar-Dataset')
  47. # opt
  48. parser.add_argument('--is_train', action="store_true", default=False,
  49. help='train or not.')
  50. args = parser.parse_args()
  51. # dataset
  52. dataset = CifarDataset(is_train=args.is_train)
  53. print('Dataset size: ', len(dataset))
  54. for i in range(1000):
  55. image, target = dataset.pull_image(i)
  56. # to BGR
  57. image = image[..., (2, 1, 0)]
  58. cv2.imshow('image', image)
  59. cv2.waitKey(0)