| 1234567891011121314151617181920212223242526272829303132333435363738 |
- import torch
- from .cifar import CifarDataset
- from .custom import CustomDataset
def build_dataset(args, is_train=False):
    """Build the dataset named by ``args.dataset`` and record its settings on ``args``.

    Side effect: writes dataset-specific attributes onto ``args`` so that
    downstream code can read them:
      * ``'cifar10'``: ``num_classes=10``, ``img_dim=3``, ``img_size=32``, ``patch_size=4``
      * ``'custom'``:  ``img_size=224``, ``patch_size=16`` (``num_classes`` must be supplied)

    Args:
        args: Namespace-like config. Must provide ``dataset``. For ``'custom'`` it
            must also provide an integer ``num_classes``.
        is_train: Forwarded to the dataset constructor (presumably selects the
            train vs. evaluation split — confirm against the dataset classes).

    Returns:
        A ``CifarDataset`` for ``'cifar10'`` or a ``CustomDataset`` for ``'custom'``.

    Raises:
        ValueError: If ``args.dataset`` is not recognised, or if ``args.num_classes``
            is missing or not an ``int`` for the custom dataset.
    """
    dataset = args.dataset

    # ----------------- CIFAR dataset -----------------
    if dataset == 'cifar10':
        args.num_classes = 10
        args.img_dim = 3
        args.img_size = 32
        args.patch_size = 4
        return CifarDataset(is_train)

    # ----------------- Custom dataset -----------------
    if dataset == 'custom':
        # The class count is not set here for the custom dataset, so it must come
        # from the caller. Validate with a real exception: ``assert`` is stripped
        # under ``python -O``.
        num_classes = getattr(args, 'num_classes', None)
        if not isinstance(num_classes, int):
            raise ValueError(
                "args.num_classes must be an int for the custom dataset, "
                "got {!r}".format(num_classes))
        args.img_size = 224
        args.patch_size = 16
        return CustomDataset(args, is_train)

    # Fail loudly rather than returning None, which would only surface later
    # as a confusing error wherever the dataset is first used.
    raise ValueError("Unknown dataset: {}".format(dataset))
-
def build_dataloader(args, dataset, is_train=False):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Training loaders draw samples in random order and drop the final incomplete
    batch, so every step sees exactly ``args.batch_size`` samples. Evaluation
    loaders keep dataset order and keep the final partial batch, so every sample
    is visited exactly once. Both loaders request pinned host memory.

    Args:
        args: Config providing ``batch_size`` and ``num_workers``.
        dataset: A map-style dataset (must support ``len()`` and indexing).
        is_train: If True, build the shuffled, ``drop_last`` training loader.

    Returns:
        The configured ``DataLoader``.
    """
    import warnings

    if is_train:
        # With drop_last=True, a dataset smaller than one batch yields no batches
        # at all, which would make every training epoch silently empty.
        if len(dataset) < args.batch_size:
            warnings.warn(
                "Training set has {} samples, fewer than batch_size={}; with "
                "drop_last=True the loader yields no batches.".format(
                    len(dataset), args.batch_size),
                stacklevel=2)
        sampler = torch.utils.data.RandomSampler(dataset)
        batch_sampler = torch.utils.data.BatchSampler(
            sampler, args.batch_size, drop_last=True)
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler,
            num_workers=args.num_workers, pin_memory=True)

    return torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
|