__init__.py 1.4 KB

import torch

from .cifar import CifarDataset
from .mnist import MnistDataset
from .custom import CustomDataset

def build_dataset(args, is_train=False):
    """Build the requested dataset and fill dataset-specific fields on ``args``."""
    if args.dataset == 'cifar10':
        args.img_dim = 3
        args.img_size = 32
        args.mlp_in_dim = 32 * 32 * 3
        args.num_classes = 10
        args.patch_size = 4
        return CifarDataset(is_train)
    elif args.dataset == 'mnist':
        args.img_dim = 1
        args.img_size = 28
        args.mlp_in_dim = 28 * 28 * 1
        args.num_classes = 10
        args.patch_size = 4
        return MnistDataset(is_train)
    elif args.dataset == 'custom':
        # A custom dataset must supply its own class count.
        assert args.num_classes is not None and isinstance(args.num_classes, int)
        args.img_dim = 3  # mlp_in_dim below assumes 3-channel (RGB) images
        args.img_size = 224
        args.mlp_in_dim = 224 * 224 * 3
        args.patch_size = 16
        return CustomDataset(args, is_train)
    raise ValueError(f"unknown dataset '{args.dataset}'")

def build_dataloader(args, dataset, is_train=False):
    if is_train:
        # Training: shuffle via a random sampler and drop the last incomplete batch.
        sampler = torch.utils.data.RandomSampler(dataset)
        batch_sampler_train = torch.utils.data.BatchSampler(
            sampler, args.batch_size, drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler_train,
            num_workers=args.num_workers, pin_memory=True)
    else:
        # Evaluation: fixed order, keep every sample.
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers)
    return dataloader
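
# Usage sketch (not part of the original module): it assumes this package is
# importable as `datasets` and that `args` comes from the project's argparse
# setup; the Namespace below only mimics the fields these builders actually
# read (dataset, batch_size, num_workers, num_classes).
#
#     from argparse import Namespace
#     from datasets import build_dataset, build_dataloader
#
#     args = Namespace(dataset='cifar10', batch_size=128,
#                      num_workers=2, num_classes=None)
#     train_set = build_dataset(args, is_train=True)   # also fills args.img_size, args.patch_size, ...
#     train_loader = build_dataloader(args, train_set, is_train=True)
#     for batch in train_loader:                       # batches of 128 CIFAR-10 samples
#         ...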