cifar.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import numpy as np
  3. import torch.utils.data as data
  4. import torchvision.transforms as T
  5. from torchvision.datasets import CIFAR10
  6. class CifarDataset(data.Dataset):
  7. def __init__(self, is_train=False):
  8. super().__init__()
  9. # ----------------- basic parameters -----------------
  10. self.pixel_mean = [0.5, 0.5, 0.5]
  11. self.pixel_std = [0.5, 0.5, 0.5]
  12. self.is_train = is_train
  13. self.image_set = 'train' if is_train else 'val'
  14. # ----------------- dataset & transforms -----------------
  15. self.transform = self.build_transform()
  16. path = os.path.dirname(os.path.abspath(__file__))
  17. if is_train:
  18. self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=True, download=True, transform=self.transform)
  19. else:
  20. self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=False, download=True, transform=self.transform)
  21. def __len__(self):
  22. return len(self.dataset)
  23. def __getitem__(self, index):
  24. image, target = self.dataset[index]
  25. return image, target
  26. def pull_image(self, index):
  27. # laod data
  28. image, target = self.dataset[index]
  29. # denormalize image
  30. image = image.permute(1, 2, 0).numpy()
  31. image = (image * self.pixel_std + self.pixel_mean) * 255.
  32. image = image.astype(np.uint8)
  33. image = image.copy()
  34. return image, target
  35. def build_transform(self):
  36. if self.is_train:
  37. transforms = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
  38. else:
  39. transforms = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
  40. return transforms
  41. if __name__ == "__main__":
  42. import cv2
  43. # dataset
  44. dataset = CifarDataset(is_train=True)
  45. print('Dataset size: ', len(dataset))
  46. for i in range(len(dataset)):
  47. image, target = dataset.pull_image(i)
  48. # to BGR
  49. image = image[..., (2, 1, 0)]
  50. cv2.imshow('image', image)
  51. cv2.waitKey(0)