__init__.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import torch
  2. from .cifar import CifarDataset
  3. from .custom import CustomDataset
  def build_dataset(args, is_train=False):
      """Build the dataset selected by ``args.dataset`` and record its geometry on ``args``.

      As a side effect, dataset-dependent model hyper-parameters (image size,
      patch size, and for CIFAR also class count and channel count) are written
      back onto ``args`` so later model construction can read them.

      Args:
          args: Config namespace. ``args.dataset`` selects the dataset
              (``'cifar10'`` or ``'custom'``). For ``'custom'``,
              ``args.num_classes`` must already be set to a positive int.
          is_train: Whether to build the training split.

      Returns:
          A ``CifarDataset`` (for ``'cifar10'``) or ``CustomDataset``
          (for ``'custom'``) instance.

      Raises:
          ValueError: If ``args.dataset`` is not a known dataset name, or if
              ``args.num_classes`` is missing or not a positive int for the
              custom dataset.
      """
      # ----------------- CIFAR dataset -----------------
      if args.dataset == 'cifar10':
          args.num_classes = 10
          args.img_dim = 3
          args.img_size = 32
          args.patch_size = 4
          return CifarDataset(is_train)

      # ----------------- Custom dataset -----------------
      if args.dataset == 'custom':
          num_classes = getattr(args, 'num_classes', None)
          # Explicit check rather than `assert` so validation survives `python -O`.
          # bool is rejected explicitly because it is a subclass of int.
          if (not isinstance(num_classes, int)
                  or isinstance(num_classes, bool)
                  or num_classes <= 0):
              raise ValueError(
                  "args.num_classes must be a positive int for the custom dataset, "
                  "got {!r}".format(num_classes))
          args.img_size = 224
          args.patch_size = 16
          # NOTE(review): unlike the CIFAR branch, args.img_dim is not set here;
          # confirm that CustomDataset or the caller provides it.
          return CustomDataset(args, is_train)

      # Previously this only printed and returned None, which made the failure
      # surface later as an unrelated error inside the DataLoader.
      raise ValueError("Unknown dataset: {}".format(args.dataset))
  def build_dataloader(args, dataset, is_train=False):
      """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

      Training loaders:
        * draw samples in random order via ``RandomSampler``;
        * drop the final incomplete batch, so every step sees exactly
          ``args.batch_size`` samples.

      Evaluation loaders:
        * keep the dataset order;
        * keep the last partial batch, so every sample is visited.

      Args:
          args: Config namespace providing ``batch_size`` and ``num_workers``.
          dataset: The dataset to load from. The training path uses
              ``RandomSampler``, which needs ``len(dataset)``.
          is_train: Select training (shuffled, drop-last) or evaluation
              (sequential, keep-last) behavior.

      Returns:
          The configured ``DataLoader``.
      """
      if is_train:
          # Already in the training branch, so drop_last is unconditionally True.
          # NOTE: with drop_last=True, a dataset smaller than batch_size yields
          # zero training batches.
          batch_sampler = torch.utils.data.BatchSampler(
              torch.utils.data.RandomSampler(dataset), args.batch_size, drop_last=True)
          return torch.utils.data.DataLoader(
              dataset,
              batch_sampler=batch_sampler,
              num_workers=args.num_workers,
              pin_memory=True,
          )

      # pin_memory matches the training loader. It only speeds up host-to-device
      # copies and does not change the values produced.
      return torch.utils.data.DataLoader(
          dataset=dataset,
          batch_size=args.batch_size,
          shuffle=False,
          num_workers=args.num_workers,
          pin_memory=True,
      )