screencts.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import argparse
  2. import sys
  3. import numpy as np
  4. import torch.nn as nn
  5. from torch.autograd import Variable
  6. from torch.optim import SGD
  7. from torch.utils.data import Dataset, DataLoader
  8. from util.util import enumerateWithEstimate, prhist
  9. from .dsets import getNoduleInfoList, getCtSize, getCt
  10. from util.logconf import logging
  11. # from .model import LunaModel
  12. log = logging.getLogger(__name__)
  13. # log.setLevel(logging.WARN)
  14. log.setLevel(logging.INFO)
  15. # log.setLevel(logging.DEBUG)
  16. class LunaScreenCtDataset(Dataset):
  17. def __init__(self):
  18. self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
  19. def __len__(self):
  20. return len(self.series_list)
  21. def __getitem__(self, ndx):
  22. series_uid = self.series_list[ndx]
  23. ct = getCt(series_uid)
  24. mid_ndx = ct.ary.shape[0] // 2
  25. air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, altben_mask = ct.build2dLungMask(mid_ndx)
  26. return series_uid, float(dense_mask.sum() / denoise_mask.sum())
  27. class LunaScreenCtApp(object):
  28. @classmethod
  29. def __init__(self, sys_argv=None):
  30. if sys_argv is None:
  31. sys_argv = sys.argv[1:]
  32. parser = argparse.ArgumentParser()
  33. parser.add_argument('--batch-size',
  34. help='Batch size to use for training',
  35. default=4,
  36. type=int,
  37. )
  38. parser.add_argument('--num-workers',
  39. help='Number of worker processes for background data loading',
  40. default=8,
  41. type=int,
  42. )
  43. # parser.add_argument('--scaled',
  44. # help="Scale the CT chunks to square voxels.",
  45. # default=False,
  46. # action='store_true',
  47. # )
  48. self.cli_args = parser.parse_args(sys_argv)
  49. def main(self):
  50. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  51. self.prep_dl = DataLoader(
  52. LunaScreenCtDataset(),
  53. batch_size=self.cli_args.batch_size,
  54. num_workers=self.cli_args.num_workers,
  55. )
  56. series2ratio_dict = {}
  57. batch_iter = enumerateWithEstimate(
  58. self.prep_dl,
  59. "Screening CTs",
  60. start_ndx=self.prep_dl.num_workers,
  61. )
  62. for batch_ndx, batch_tup in batch_iter:
  63. series_list, ratio_list = batch_tup
  64. for series_uid, ratio_float in zip(series_list, ratio_list):
  65. series2ratio_dict[series_uid] = ratio_float
  66. # break
  67. prhist(list(series2ratio_dict.values()))
  68. if __name__ == '__main__':
  69. sys.exit(LunaScreenCtApp().main() or 0)