# mnist.py
  1. import os
  2. import torch.utils.data as data
  3. import torchvision.transforms as T
  4. from torchvision.datasets import MNIST
  5. class MnistDataset(data.Dataset):
  6. def __init__(self, is_train=False, transform=None):
  7. super().__init__()
  8. # ----------------- basic parameters -----------------
  9. self.is_train = is_train
  10. self.pixel_mean = [0.]
  11. self.pixel_std = [1.]
  12. self.image_set = 'train' if is_train else 'val'
  13. # ----------------- dataset & transforms -----------------
  14. self.transform = self.build_transform()
  15. path = os.path.dirname(os.path.abspath(__file__))
  16. if is_train:
  17. self.dataset = MNIST(os.path.join(path, 'mnist_data/'), train=True, download=True, transform=self.transform)
  18. else:
  19. self.dataset = MNIST(os.path.join(path, 'mnist_data/'), train=False, download=True, transform=self.transform)
  20. def __len__(self):
  21. return len(self.dataset)
  22. def __getitem__(self, index):
  23. image, target = self.dataset[index]
  24. return image, target
  25. def pull_image(self, index):
  26. # laod data
  27. image, target = self.dataset[index]
  28. # denormalize image
  29. image = image.permute(1, 2, 0).numpy()
  30. image = image.copy()
  31. return image, target
  32. def build_transform(self):
  33. if self.is_train:
  34. transforms = T.Compose([T.ToTensor(),])
  35. else:
  36. transforms = T.Compose([T.ToTensor(),])
  37. return transforms
  38. if __name__ == "__main__":
  39. import cv2
  40. # dataset
  41. dataset = MnistDataset(is_train=True)
  42. print('Dataset size: ', len(dataset))
  43. for i in range(1000):
  44. image, target = dataset.pull_image(i)
  45. cv2.imshow('image', image)
  46. cv2.waitKey(0)