diagnose.py

import argparse
import glob
import os
import sys

import numpy as np
import scipy.ndimage.measurements as measure
import scipy.ndimage.morphology as morph

import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getNoduleInfoList, NoduleInfoTuple
from .model_seg import UNetWrapper
from .model_cls import LunaModel, AlternateLunaModel
from util.logconf import logging
from util.util import xyz2irc, irc2xyz

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
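
# LunaDiagnoseApp runs the saved segmentation and classification models end to
# end: segmentation proposes nodule candidates on each CT, the candidates are
# clustered into discrete nodules, and the classifier scores each nodule; the
# highest score in a series becomes that series' diagnosis.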
class LunaDiagnoseApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            log.debug(sys.argv)
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for inference',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--series-uid',
            help='Limit inference to this Series UID only.',
            default=None,
            type=str,
        )
        parser.add_argument('--include-train',
            help="Include data that was in the training set. (default: validation data only)",
            action='store_true',
            default=False,
        )
        parser.add_argument('--segmentation-path',
            help="Path to the saved segmentation model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--classification-path',
            help="Path to the saved classification model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch12',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )

        self.cli_args = parser.parse_args(sys_argv)
        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        if not self.cli_args.segmentation_path:
            self.cli_args.segmentation_path = self.initModelPath('seg')

        if not self.cli_args.classification_path:
            self.cli_args.classification_path = self.initModelPath('cls')

        self.seg_model, self.cls_model = self.initModels()
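
    # Finds the most recent saved checkpoint for the given model type ('seg' or
    # 'cls') under data-unversioned/, falling back to the pretrained checkpoints
    # shipped under data/ when no locally trained model exists.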
    def initModelPath(self, type_str):
        local_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
        )

        file_list = glob.glob(local_path)
        if not file_list:
            pretrained_path = os.path.join(
                'data',
                'part2',
                'models',
                type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
            )
            file_list = glob.glob(pretrained_path)
        else:
            pretrained_path = None

        file_list.sort()

        try:
            return file_list[-1]
        except IndexError:
            log.debug([local_path, pretrained_path, file_list])
            raise
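
    # Restores both models from their checkpoints, puts them in eval mode, and
    # moves them to the GPU (wrapped in DataParallel when more than one is present).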
    def initModels(self):
        log.debug(self.cli_args.segmentation_path)
        seg_dict = torch.load(self.cli_args.segmentation_path)

        seg_model = UNetWrapper(
            in_channels=8,
            n_classes=1,
            depth=4,
            wf=3,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )
        seg_model.load_state_dict(seg_dict['model_state'])
        seg_model.eval()

        log.debug(self.cli_args.classification_path)
        cls_dict = torch.load(self.cli_args.classification_path)

        cls_model = LunaModel()
        # cls_model = AlternateLunaModel()
        cls_model.load_state_dict(cls_dict['model_state'])
        cls_model.eval()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                seg_model = nn.DataParallel(seg_model)
                cls_model = nn.DataParallel(cls_model)
            seg_model = seg_model.to(self.device)
            cls_model = cls_model.to(self.device)

        return seg_model, cls_model
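
    # Builds a DataLoader that yields every slice of one CT (with context
    # slices on either side) for the segmentation model.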
    def initSegmentationDl(self, series_uid):
        seg_ds = Luna2dSegmentationDataset(
            contextSlices_count=3,
            series_uid=series_uid,
            fullCt_bool=True,
        )
        seg_dl = DataLoader(
            seg_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return seg_dl
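
    # Builds a DataLoader over the candidate nodules produced by segmentation,
    # for feeding the classification model.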
    def initClassificationDl(self, noduleInfo_list):
        cls_ds = LunaDataset(
            sortby_str='series_uid',
            noduleInfo_list=noduleInfo_list,
        )
        cls_dl = DataLoader(
            cls_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return cls_dl
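
    # Driver: segments every requested series, clusters the segmentation output
    # into candidate nodules, classifies each candidate, and logs per-series
    # diagnoses against the known malignancy annotations.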
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = set(
            noduleInfo_tup.series_uid
            for noduleInfo_tup in val_ds.noduleInfo_list
        )
        malignant_set = set(
            noduleInfo_tup.series_uid
            for noduleInfo_tup in getNoduleInfoList()
            if noduleInfo_tup.isMalignant_bool
        )

        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(
                noduleInfo_tup.series_uid
                for noduleInfo_tup in getNoduleInfoList()
            )

        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
        val_list = sorted(series_set & val_set)

        noduleInfo_list = []
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        for _series_ndx, series_uid in series_iter:
            ct, output_a, _mask_a, clean_a = self.segmentCt(series_uid)

            noduleInfo_list += self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_a,
            )
            # if _series_ndx > 10:
            #     break

        cls_dl = self.initClassificationDl(noduleInfo_list)
        series2diagnosis_dict = {}
        batch_iter = enumerateWithEstimate(
            cls_dl,
            "Cls all",
            start_ndx=cls_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            input_t, _, series_list, center_list = batch_tup

            input_g = input_t.to(self.device)
            with torch.no_grad():
                _logits_g, probability_g = self.cls_model(input_g)

            classifications_list = zip(
                series_list,
                center_list,
                probability_g[:, 1].to('cpu'),
            )
            for cls_tup in classifications_list:
                series_uid, center_irc, probability_t = cls_tup
                probability_float = probability_t.item()

                this_tup = (probability_float, tuple(center_irc))
                # Keep only the highest-probability nodule seen so far for each series;
                # the first nodule of a series is always recorded.
                current_tup = series2diagnosis_dict.get(series_uid)
                try:
                    assert np.all(np.isfinite(tuple(center_irc)))
                    if current_tup is None or this_tup > current_tup:
                        log.debug([series_uid, this_tup])
                        series2diagnosis_dict[series_uid] = this_tup
                except Exception:
                    log.debug([(type(x), x) for x in this_tup]
                              + ([(type(x), x) for x in current_tup] if current_tup else []))
                    raise
        log.info('Training set:')
        self.logResults('Training', train_list, series2diagnosis_dict, malignant_set)
        log.info('Validation set:')
        self.logResults('Validation', val_list, series2diagnosis_dict, malignant_set)
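
    # Runs the segmentation model over every slice of one CT and returns the
    # raw per-voxel output, the thresholded mask, and a cleaned mask produced
    # by eroding (to drop small specks) and then dilating the threshold mask.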
    def segmentCt(self, series_uid):
        with torch.no_grad():
            ct = getCt(series_uid)

            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
            seg_dl = self.initSegmentationDl(series_uid)
            for batch_tup in seg_dl:
                input_t = batch_tup[0]
                ndx_list = batch_tup[6]

                input_g = input_t.to(self.device)
                prediction_g = self.seg_model(input_g)

                for i, sample_ndx in enumerate(ndx_list):
                    output_a[sample_ndx] = prediction_g[i].cpu().numpy()

            mask_a = output_a > 0.5
            clean_a = morph.binary_erosion(mask_a, iterations=1)
            clean_a = morph.binary_dilation(clean_a, iterations=2)

        return ct, output_a, mask_a, clean_a
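
    # Groups connected voxels of the cleaned mask into discrete candidate
    # nodules (center of mass is weighted by HU + 1001 so all weights are
    # positive) and converts each center to patient XYZ coordinates for the
    # classification dataset.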
    def clusterSegmentationOutput(self, series_uid, ct, clean_a):
        noduleLabel_a, nodule_count = measure.label(clean_a)
        centerIrc_list = measure.center_of_mass(
            ct.hu_a + 1001,
            labels=noduleLabel_a,
            index=list(range(1, nodule_count + 1)),
        )

        # n = 1298
        # log.debug([
        #     (noduleLabel_a == n).sum(),
        #     np.where(noduleLabel_a == n),
        #
        #     ct.hu_a[noduleLabel_a == n].sum(),
        #     (ct.hu_a + 1000)[noduleLabel_a == n].sum(),
        # ])
        # if nodule_count == 1:
        #     centerIrc_list = [centerIrc_list]

        noduleInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_tup,
            )
            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, nodule_count])
            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
            noduleInfo_tup = NoduleInfoTuple(False, 0.0, series_uid, center_xyz)
            noduleInfo_list.append(noduleInfo_tup)

        return noduleInfo_list
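
    # Compares per-series predictions against the malignancy annotations and
    # logs a per-series verdict, the confusion-matrix percentages, and
    # precision, recall, and F1.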
    def logResults(self, mode_str, filtered_list, series2diagnosis_dict, malignant_set):
        count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
        for series_uid in filtered_list:
            probability_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
            if center_irc is not None:
                center_irc = tuple(int(x.item()) for x in center_irc)
            malignant_bool = series_uid in malignant_set
            prediction_bool = probability_float > 0.5
            correct_bool = malignant_bool == prediction_bool

            if malignant_bool and prediction_bool:
                count_dict['tp'] += 1
            if not malignant_bool and not prediction_bool:
                count_dict['tn'] += 1
            if not malignant_bool and prediction_bool:
                count_dict['fp'] += 1
            if malignant_bool and not prediction_bool:
                count_dict['fn'] += 1

            log.info("{} {} Mal:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
                mode_str,
                series_uid,
                malignant_bool,
                prediction_bool,
                correct_bool,
                probability_float,
                center_irc,
            ))
        total_count = sum(count_dict.values())
        percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
        precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
        recall = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
        percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)

        log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
            **percent_dict,
        ))
        log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
            **percent_dict,
        ))
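
# Typical invocation, assuming this file lives in a p2ch12 package as implied
# by the relative imports and the default --tb-prefix (adjust to your layout):
#     python -m p2ch12.diagnose --series-uid <series-uid>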
if __name__ == '__main__':
    sys.exit(LunaDiagnoseApp().main() or 0)