| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- import argparse
- import sys
- import numpy as np
- import torch.nn as nn
- from torch.autograd import Variable
- from torch.optim import SGD
- from torch.utils.data import DataLoader
- from util.util import enumerateWithEstimate
- from .dsets import LunaDataset, getCtSampleSize
- from util.logconf import logging
- # from .model import LunaModel
# Module-level logger, pinned to INFO so log.info(...) messages from this
# module are emitted. Swap in logging.WARN (quieter) or logging.DEBUG
# (more verbose) here to change verbosity.
log = logging.getLogger(__name__)
log.setLevel(level=logging.INFO)
class LunaPrepCacheApp(object):
    """Command-line app that walks the LUNA dataset once to fill its cache.

    The dataset is iterated through a ``DataLoader``, and
    ``getCtSampleSize`` is called once for each distinct series UID seen
    in each batch. The loaded sample tensors are discarded. Only the side
    effects of loading matter (presumably populating the cache, per the
    "Stuffing cache" progress label -- confirm in ``dsets``).
    """

    def __init__(self, sys_argv=None):
        """Parse command-line options into ``self.cli_args``.

        The result is stored per instance, so separately constructed apps
        keep their own parsed arguments.

        Args:
            sys_argv: List of argument strings to parse. When ``None``,
                ``sys.argv[1:]`` is used.
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Iterate the whole dataset once, calling getCtSampleSize per series.

        Side effects: sets ``self.prep_dl`` to the DataLoader that is used.

        Returns:
            None.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaDataset(
                sortby_str='series_uid',
            ),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        # start_ndx is forwarded to enumerateWithEstimate. It presumably
        # excludes the first num_workers batches (worker start-up) from the
        # time-remaining estimate -- TODO confirm in util.util.
        batch_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        for _batch_ndx, batch_tup in batch_iter:
            # Only the series UIDs are needed here; the tensors are discarded.
            _nodule_tensor, _malignant_tensor, series_list, _center_list = batch_tup
            # sortby_str='series_uid' suggests samples from one series are
            # adjacent, so a batch may repeat a UID. Dedupe before calling.
            for series_uid in sorted(set(series_list)):
                getCtSampleSize(series_uid)
if __name__ == '__main__':
    # A falsy result from main() (e.g. None) is reported as exit status 0.
    _status = LunaPrepCacheApp().main()
    sys.exit(_status if _status else 0)
|