imagenet.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import os
  2. import PIL
  3. import numpy as np
  4. from timm.data import create_transform
  5. import torch.utils.data as data
  6. import torchvision.transforms as T
  7. from torchvision.datasets import ImageFolder
  8. import torchvision.transforms as transforms
  9. class ImageNet1KDataset(data.Dataset):
  10. def __init__(self, args, is_train=False, transform=None):
  11. super().__init__()
  12. # ----------------- basic parameters -----------------
  13. self.args = args
  14. self.is_train = is_train
  15. self.pixel_mean = [0.485, 0.456, 0.406]
  16. self.pixel_std = [0.229, 0.224, 0.225]
  17. print("Pixel mean: {}".format(self.pixel_mean))
  18. print("Pixel std: {}".format(self.pixel_std))
  19. self.image_set = 'train' if is_train else 'val'
  20. self.data_path = os.path.join(args.root, self.image_set)
  21. # ----------------- dataset & transforms -----------------
  22. self.transform = transform if transform is not None else self.build_transform(args)
  23. self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
  24. def __len__(self):
  25. return len(self.dataset)
  26. def __getitem__(self, index):
  27. image, target = self.dataset[index]
  28. return image, target
  29. def pull_image(self, index):
  30. # laod data
  31. image, target = self.dataset[index]
  32. # denormalize image
  33. image = image.permute(1, 2, 0).numpy()
  34. image = (image * self.pixel_std + self.pixel_mean) * 255.
  35. image = image.astype(np.uint8)
  36. image = image.copy()
  37. return image, target
  38. def build_transform(self, args):
  39. if self.is_train:
  40. transforms = create_transform(input_size = args.img_size,
  41. is_training = True,
  42. color_jitter = args.color_jitter,
  43. auto_augment = args.aa,
  44. interpolation = 'bicubic',
  45. re_prob = args.reprob,
  46. re_mode = args.remode,
  47. re_count = args.recount,
  48. mean = self.pixel_mean,
  49. std = self.pixel_std,
  50. )
  51. else:
  52. t = []
  53. if args.img_size <= 224:
  54. crop_pct = 224 / 256
  55. else:
  56. crop_pct = 1.0
  57. size = int(args.img_size / crop_pct)
  58. t.append(
  59. T.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
  60. )
  61. t.append(T.CenterCrop(args.img_size))
  62. t.append(T.ToTensor())
  63. t.append(T.Normalize(self.pixel_mean, self.pixel_std))
  64. transforms = T.Compose(t)
  65. return transforms
  66. if __name__ == "__main__":
  67. import cv2
  68. import torch
  69. import argparse
  70. parser = argparse.ArgumentParser(description='ImageNet-Dataset')
  71. # opt
  72. parser.add_argument('--root', default='/mnt/share/ssd2/dataset/imagenet/',
  73. help='data root')
  74. parser.add_argument('--img_size', default=224, type=int,
  75. help='input image size.')
  76. args = parser.parse_args()
  77. # Transforms
  78. train_transform = transforms.Compose([
  79. transforms.RandomResizedCrop(args.img_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic
  80. transforms.RandomHorizontalFlip(),
  81. transforms.ToTensor(),
  82. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
  83. # Dataset
  84. dataset = ImageNet1KDataset(args, is_train=True)
  85. print('Dataset size: ', len(dataset))
  86. for i in range(1000):
  87. image, target = dataset.pull_image(i)
  88. # to BGR
  89. image = image[..., (2, 1, 0)]
  90. cv2.imshow('image', image)
  91. cv2.waitKey(0)