custom.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import os
  2. import PIL
  3. import numpy as np
  4. import torch.utils.data as data
  5. import torchvision.transforms as T
  6. from torchvision.datasets import ImageFolder
  7. class CustomDataset(data.Dataset):
  8. def __init__(self, args, is_train=False):
  9. super().__init__()
  10. # ----------------- basic parameters -----------------
  11. self.args = args
  12. self.is_train = is_train
  13. self.pixel_mean = [0.485, 0.456, 0.406]
  14. self.pixel_std = [0.229, 0.224, 0.225]
  15. print("Pixel mean: {}".format(self.pixel_mean))
  16. print("Pixel std: {}".format(self.pixel_std))
  17. self.image_set = 'train' if is_train else 'val'
  18. self.data_path = os.path.join(args.root, self.image_set)
  19. # ----------------- dataset & transforms -----------------
  20. self.transform = self.build_transform()
  21. self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
  22. def __len__(self):
  23. return len(self.dataset)
  24. def __getitem__(self, index):
  25. image, target = self.dataset[index]
  26. return image, target
  27. def pull_image(self, index):
  28. # laod data
  29. image, target = self.dataset[index]
  30. # denormalize image
  31. image = image.permute(1, 2, 0).numpy()
  32. image = (image * self.pixel_std + self.pixel_mean) * 255.
  33. image = image.astype(np.uint8)
  34. image = image.copy()
  35. return image, target
  36. def build_transform(self):
  37. if self.is_train:
  38. transforms = T.Compose([
  39. T.RandomResizedCrop(224),
  40. T.RandomHorizontalFlip(0.5),
  41. T.ToTensor(),
  42. T.Normalize(self.pixel_mean,
  43. self.pixel_std)])
  44. else:
  45. transforms = T.Compose([
  46. T.Resize(224, interpolation=PIL.Image.BICUBIC),
  47. T.CenterCrop(224),
  48. T.ToTensor(),
  49. T.Normalize(self.pixel_mean, self.pixel_std),
  50. ])
  51. return transforms
  52. if __name__ == "__main__":
  53. import cv2
  54. import argparse
  55. parser = argparse.ArgumentParser(description='Custom-Dataset')
  56. # opt
  57. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/classification/dataset/Animals/',
  58. help='data root')
  59. parser.add_argument('--img_size', default=224, type=int,
  60. help='input image size.')
  61. args = parser.parse_args()
  62. # Dataset
  63. dataset = CustomDataset(args, is_train=True)
  64. print('Dataset size: ', len(dataset))
  65. for i in range(len(dataset)):
  66. image, target = dataset.pull_image(i)
  67. # to BGR
  68. image = image[..., (2, 1, 0)]
  69. cv2.imshow('image', image)
  70. cv2.waitKey(0)