# datasets/__init__.py — dataset and dataloader factory helpers.
  1. import torch.utils.data as data
  2. from .cifar import CifarDataset
  3. from .mnist import MnistDataset
  4. from .imagenet import ImageNet1KDataset
  5. from .custom import CustomDataset
  6. def build_dataset(args, transform=None, is_train=False):
  7. if args.dataset == 'cifar10':
  8. args.num_classes = 10
  9. args.img_dim = 3
  10. return CifarDataset(is_train, transform)
  11. elif args.dataset == 'mnist':
  12. args.num_classes = 10
  13. args.img_dim = 1
  14. return MnistDataset(is_train, transform)
  15. elif args.dataset == 'imagenet_1k':
  16. args.num_classes = 1000
  17. args.img_dim = 3
  18. return ImageNet1KDataset(args, is_train, transform)
  19. elif args.dataset == 'custom':
  20. assert args.num_classes is not None and isinstance(args.num_classes, int)
  21. args.img_dim = 3
  22. return CustomDataset(args, is_train, transform)
  23. def build_dataloader(args, dataset, is_train=False):
  24. if is_train:
  25. sampler = data.distributed.DistributedSampler(dataset) if args.distributed else data.RandomSampler(dataset)
  26. batch_sampler_train = data.BatchSampler(sampler, args.batch_size // args.world_size, drop_last=True if is_train else False)
  27. dataloader = data.DataLoader(dataset, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True)
  28. else:
  29. dataloader = data.DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
  30. return dataloader