| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import os
- import PIL
- import numpy as np
- from timm.data import create_transform
- import torch.utils.data as data
- import torchvision.transforms as T
- from torchvision.datasets import ImageFolder
- import torchvision.transforms as transforms
class ImageNet1KDataset(data.Dataset):
    """ImageNet-1K classification dataset backed by ``torchvision`` ``ImageFolder``.

    Samples are read from ``<args.root>/train`` when ``is_train`` is true and
    from ``<args.root>/val`` otherwise.

    Args:
        args: Namespace-like object. ``args.root`` and, when no ``transform``
            is supplied, ``args.img_size`` are always read. For the training
            pipeline ``args.color_jitter``, ``args.aa``, ``args.reprob``,
            ``args.remode`` and ``args.recount`` are read as well.
        is_train: Selects the ``train`` split (and the training augmentation
            pipeline) instead of ``val``.
        transform: Optional callable applied to every loaded image. When
            ``None`` the pipeline returned by :meth:`build_transform` is used.
    """

    def __init__(self, args, is_train=False, transform=None):
        super().__init__()
        # ----------------- basic parameters -----------------
        self.args = args
        self.is_train = is_train
        # Per-channel statistics used for normalization and, in
        # ``pull_image``, for undoing it.
        self.pixel_mean = [0.485, 0.456, 0.406]
        self.pixel_std = [0.229, 0.224, 0.225]
        print("Pixel mean: {}".format(self.pixel_mean))
        print("Pixel std: {}".format(self.pixel_std))
        self.image_set = 'train' if is_train else 'val'
        self.data_path = os.path.join(args.root, self.image_set)
        # ----------------- dataset & transforms -----------------
        self.transform = transform if transform is not None else self.build_transform(args)
        self.dataset = ImageFolder(root=self.data_path, transform=self.transform)

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the transformed ``(image, target)`` pair at ``index``."""
        image, target = self.dataset[index]
        return image, target

    def pull_image(self, index):
        """Return sample ``index`` as a displayable ``uint8`` image.

        The transformed sample is moved to channel-last layout and the
        normalization by ``pixel_mean`` / ``pixel_std`` is undone.

        Assumes the active transform yields a 3-D channel-first torch tensor
        normalized with ``self.pixel_mean`` / ``self.pixel_std`` -- confirm
        when a custom ``transform`` is passed to the constructor.

        Returns:
            Tuple ``(image, target)`` where ``image`` is a channel-last
            ``numpy.uint8`` array and ``target`` is the sample's label.
        """
        # load data
        image, target = self.dataset[index]
        # denormalize image: channel-first tensor -> channel-last array
        image = image.permute(1, 2, 0).numpy()
        image = (image * self.pixel_std + self.pixel_mean) * 255.
        # Round and saturate before the cast: a bare ``astype(np.uint8)``
        # truncates and wraps out-of-range values (e.g. random-erasing noise
        # or interpolation overshoot) instead of clamping them.
        image = np.clip(np.rint(image), 0, 255).astype(np.uint8)
        return image, target

    def build_transform(self, args):
        """Build the default preprocessing pipeline for the current split.

        Training uses timm's ``create_transform`` configured from ``args``.
        Evaluation resizes (bicubic), center-crops to ``args.img_size``,
        converts to a tensor and normalizes.

        Args:
            args: Namespace-like object; see the class docstring for the
                attributes that are read.

        Returns:
            A callable transform pipeline.
        """
        if self.is_train:
            return create_transform(
                input_size=args.img_size,
                is_training=True,
                color_jitter=args.color_jitter,
                auto_augment=args.aa,
                interpolation='bicubic',
                re_prob=args.reprob,
                re_mode=args.remode,
                re_count=args.recount,
                mean=self.pixel_mean,
                std=self.pixel_std,
            )

        # Evaluation: for inputs up to 224 keep the 224/256 crop ratio,
        # otherwise resize directly to the target size.
        crop_pct = 224 / 256 if args.img_size <= 224 else 1.0
        resize_to = int(args.img_size / crop_pct)
        return T.Compose([
            # to maintain same ratio w.r.t. 224 images
            T.Resize(resize_to, interpolation=PIL.Image.BICUBIC),
            T.CenterCrop(args.img_size),
            T.ToTensor(),
            T.Normalize(self.pixel_mean, self.pixel_std),
        ])
if __name__ == "__main__":
    # Visual smoke test: show denormalized training samples one at a time.
    import cv2
    import torch
    import argparse

    parser = argparse.ArgumentParser(description='ImageNet-Dataset')
    # opt
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset/imagenet/',
                        help='data root')
    parser.add_argument('--img_size', default=224, type=int,
                        help='input image size.')
    args = parser.parse_args()

    # Transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(args.img_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    # Dataset
    # Pass the transform explicitly: the default training pipeline reads
    # augmentation options (color_jitter, aa, reprob, ...) that this parser
    # does not define, so omitting it raises AttributeError.
    dataset = ImageNet1KDataset(args, is_train=True, transform=train_transform)
    print('Dataset size: ', len(dataset))

    # Never index past the end of a dataset with fewer than 1000 samples.
    for i in range(min(1000, len(dataset))):
        image, target = dataset.pull_image(i)
        # to BGR
        image = image[..., (2, 1, 0)]
        cv2.imshow('image', image)
        cv2.waitKey(0)
|