- import argparse
- import sys
- import numpy as np
- import torch.nn as nn
- from torch.autograd import Variable
- from torch.optim import SGD
- from torch.utils.data import DataLoader
- from util.util import enumerateWithEstimate
- from .dsets import LunaDataset
- from util.logconf import logging
- from .model import LunaModel
# Module-level logger for this script.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)  # alternatives: logging.WARN, logging.DEBUG
class LunaPrepCacheApp:
    """Command-line app that iterates the entire LunaDataset once so the
    on-disk cache of prepared candidate chunks gets populated before
    training starts."""

    # FIX: the original decorated __init__ with @classmethod, which bound
    # `self` to the class itself at construction time and stored cli_args
    # as a shared *class* attribute. A plain instance initializer is the
    # correct (and backward-compatible) form.
    def __init__(self, sys_argv=None):
        """Parse command-line options.

        :param sys_argv: argument list to parse; defaults to sys.argv[1:]
            when None so the app can be driven programmatically in tests.
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Drain the dataset through a DataLoader; LunaDataset's
        __getitem__ side effect is what fills the disk cache, so the
        batches themselves are discarded."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaDataset(
                sortby_str='series_uid',
            ),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        # Skip the first num_workers batches in the ETA estimate: they are
        # dominated by worker-process startup, not steady-state throughput.
        batch_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        for _ in batch_iter:
            pass
# Script entry point: parse CLI args and stream the dataset once to warm the cache.
if __name__ == '__main__':
    LunaPrepCacheApp().main()
|