# prepcache.py
  1. import argparse
  2. import sys
  3. import numpy as np
  4. import torch.nn as nn
  5. from torch.autograd import Variable
  6. from torch.optim import SGD
  7. from torch.utils.data import DataLoader
  8. from util.util import enumerateWithEstimate
  9. from .dsets import PrepcacheLunaDataset, getCtSampleSize
  10. from util.logconf import logging
  11. # from .model import LunaModel
  12. log = logging.getLogger(__name__)
  13. # log.setLevel(logging.WARN)
  14. log.setLevel(logging.INFO)
  15. # log.setLevel(logging.DEBUG)
  16. class LunaPrepCacheApp:
  17. @classmethod
  18. def __init__(self, sys_argv=None):
  19. if sys_argv is None:
  20. sys_argv = sys.argv[1:]
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument('--batch-size',
  23. help='Batch size to use for training',
  24. default=1024,
  25. type=int,
  26. )
  27. parser.add_argument('--num-workers',
  28. help='Number of worker processes for background data loading',
  29. default=8,
  30. type=int,
  31. )
  32. # parser.add_argument('--scaled',
  33. # help="Scale the CT chunks to square voxels.",
  34. # default=False,
  35. # action='store_true',
  36. # )
  37. self.cli_args = parser.parse_args(sys_argv)
  38. def main(self):
  39. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  40. self.prep_dl = DataLoader(
  41. PrepcacheLunaDataset(
  42. # sortby_str='series_uid',
  43. ),
  44. batch_size=self.cli_args.batch_size,
  45. num_workers=self.cli_args.num_workers,
  46. )
  47. batch_iter = enumerateWithEstimate(
  48. self.prep_dl,
  49. "Stuffing cache",
  50. start_ndx=self.prep_dl.num_workers,
  51. )
  52. for batch_ndx, batch_tup in batch_iter:
  53. pass
  54. if __name__ == '__main__':
  55. LunaPrepCacheApp().main()