# prepcache.py
  1. import argparse
  2. import sys
  3. import numpy as np
  4. import torch.nn as nn
  5. from torch.autograd import Variable
  6. from torch.optim import SGD
  7. from torch.utils.data import DataLoader
  8. from util.util import enumerateWithEstimate
  9. from .dsets import LunaDataset
  10. from util.logconf import logging
  11. from .model import LunaModel
# Module-level logger from the project's logconf wrapper; INFO is the
# working default — uncomment WARN/DEBUG below to change verbosity.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)
  16. class LunaPrepCacheApp(object):
  17. @classmethod
  18. def __init__(self, sys_argv=None):
  19. if sys_argv is None:
  20. sys_argv = sys.argv[1:]
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument('--batch-size',
  23. help='Batch size to use for training',
  24. default=1024,
  25. type=int,
  26. )
  27. parser.add_argument('--num-workers',
  28. help='Number of worker processes for background data loading',
  29. default=8,
  30. type=int,
  31. )
  32. self.cli_args = parser.parse_args(sys_argv)
  33. def main(self):
  34. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  35. self.prep_dl = DataLoader(
  36. LunaDataset(
  37. sortby_str='series_uid',
  38. ),
  39. batch_size=self.cli_args.batch_size,
  40. num_workers=self.cli_args.num_workers,
  41. )
  42. batch_iter = enumerateWithEstimate(
  43. self.prep_dl,
  44. "Stuffing cache",
  45. start_ndx=self.prep_dl.num_workers,
  46. )
  47. for _ in batch_iter:
  48. pass
  49. if __name__ == '__main__':
  50. LunaPrepCacheApp().main()