nodule_analysis.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. import argparse
  2. import glob
  3. import os
  4. import sys
  5. import numpy as np
  6. import scipy.ndimage.measurements as measurements
  7. import scipy.ndimage.morphology as morphology
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim
  11. from torch.utils.data import DataLoader
  12. from util.util import enumerateWithEstimate
  13. from p2ch13.dsets import Luna2dSegmentationDataset
  14. from .dsets import LunaDataset, getCt, getCandidateInfoDict, getCandidateInfoList, CandidateInfoTuple
  15. from p2ch13.model import UNetWrapper
  16. import p2ch14.model
  17. from util.logconf import logging
  18. from util.util import xyz2irc, irc2xyz
  19. log = logging.getLogger(__name__)
  20. # log.setLevel(logging.WARN)
  21. # log.setLevel(logging.INFO)
  22. log.setLevel(logging.DEBUG)
  23. logging.getLogger("p2ch13.dsets").setLevel(logging.WARNING)
  24. logging.getLogger("p2ch14.dsets").setLevel(logging.WARNING)
  25. def print_confusion(label, confusions, do_mal):
  26. row_labels = ['Non-Nodules', 'Benign', 'Malignant']
  27. if do_mal:
  28. col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Benign', 'Pred. Malignant']
  29. else:
  30. col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Nodule']
  31. confusions[:, -2] += confusions[:, -1]
  32. confusions = confusions[:, :-1]
  33. cell_width = 16
  34. f = '{:>' + str(cell_width) + '}'
  35. print(label)
  36. print(' | '.join([f.format(s) for s in col_labels]))
  37. for i, (l, r) in enumerate(zip(row_labels, confusions)):
  38. r = [l] + list(r)
  39. if i == 0:
  40. r[1] = ''
  41. print(' | '.join([f.format(i) for i in r]))
  42. def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
  43. # Returns 3x4 confusion matrix for:
  44. # Rows: Truth: Non-Nodules, Benign, Malignant
  45. # Cols: Not Detected, Detected by Seg, Detected as Benign, Detected as Malignant
  46. # If one true nodule matches multiple detections, the "highest" detection is considered
  47. # If one detection matches several true nodule annotations, it counts for all of them
  48. true_nodules = [c for c in truth if c.isNodule_bool]
  49. truth_diams = np.array([c.diameter_mm for c in true_nodules])
  50. truth_xyz = np.array([c.center_xyz for c in true_nodules])
  51. detected_xyz = np.array([n[2] for n in detections])
  52. # detection classes will contain
  53. # 1 -> detected by seg but filtered by cls
  54. # 2 -> detected as benign nodule (or nodule if no malignancy model is used)
  55. # 3 -> detected as malignant nodule (if applicable)
  56. detected_classes = np.array([1 if d[0] < threshold
  57. else (2 if d[1] < threshold
  58. else 3) for d in detections])
  59. confusion = np.zeros((3, 4), dtype=np.int)
  60. if len(detected_xyz) == 0:
  61. for tn in true_nodules:
  62. confusion[2 if tn.isMal_bool else 1, 0] += 1
  63. elif len(truth_xyz) == 0:
  64. for dc in detected_classes:
  65. confusion[0, dc] += 1
  66. else:
  67. normalized_dists = np.linalg.norm(truth_xyz[:, None] - detected_xyz[None], ord=2, axis=-1) / truth_diams[:, None]
  68. matches = (normalized_dists < 0.7)
  69. unmatched_detections = np.ones(len(detections), dtype=np.bool)
  70. matched_true_nodules = np.zeros(len(true_nodules), dtype=np.int)
  71. for i_tn, i_detection in zip(*matches.nonzero()):
  72. matched_true_nodules[i_tn] = max(matched_true_nodules[i_tn], detected_classes[i_detection])
  73. unmatched_detections[i_detection] = False
  74. for ud, dc in zip(unmatched_detections, detected_classes):
  75. if ud:
  76. confusion[0, dc] += 1
  77. for tn, dc in zip(true_nodules, matched_true_nodules):
  78. confusion[2 if tn.isMal_bool else 1, dc] += 1
  79. return confusion
  80. class NoduleAnalysisApp:
  81. def __init__(self, sys_argv=None):
  82. if sys_argv is None:
  83. log.debug(sys.argv)
  84. sys_argv = sys.argv[1:]
  85. parser = argparse.ArgumentParser()
  86. parser.add_argument('--batch-size',
  87. help='Batch size to use for training',
  88. default=4,
  89. type=int,
  90. )
  91. parser.add_argument('--num-workers',
  92. help='Number of worker processes for background data loading',
  93. default=4,
  94. type=int,
  95. )
  96. parser.add_argument('--run-validation',
  97. help='Run over validation rather than a single CT.',
  98. action='store_true',
  99. default=False,
  100. )
  101. parser.add_argument('--include-train',
  102. help="Include data that was in the training set. (default: validation data only)",
  103. action='store_true',
  104. default=False,
  105. )
  106. parser.add_argument('--segmentation-path',
  107. help="Path to the saved segmentation model",
  108. nargs='?',
  109. default='data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state',
  110. )
  111. parser.add_argument('--cls-model',
  112. help="What to model class name to use for the classifier.",
  113. action='store',
  114. default='LunaModel',
  115. )
  116. parser.add_argument('--classification-path',
  117. help="Path to the saved classification model",
  118. nargs='?',
  119. default='data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state',
  120. )
  121. parser.add_argument('--malignancy-model',
  122. help="What to model class name to use for the malignancy classifier.",
  123. action='store',
  124. default='LunaModel',
  125. # default='ModifiedLunaModel',
  126. )
  127. parser.add_argument('--malignancy-path',
  128. help="Path to the saved malignancy classification model",
  129. nargs='?',
  130. default=None,
  131. )
  132. parser.add_argument('--tb-prefix',
  133. default='p2ch14',
  134. help="Data prefix to use for Tensorboard run. Defaults to chapter.",
  135. )
  136. parser.add_argument('series_uid',
  137. nargs='?',
  138. default=None,
  139. help="Series UID to use.",
  140. )
  141. self.cli_args = parser.parse_args(sys_argv)
  142. # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
  143. if not (bool(self.cli_args.series_uid) ^ self.cli_args.run_validation):
  144. raise Exception("One and only one of series_uid and --run-validation should be given")
  145. self.use_cuda = torch.cuda.is_available()
  146. self.device = torch.device("cuda" if self.use_cuda else "cpu")
  147. if not self.cli_args.segmentation_path:
  148. self.cli_args.segmentation_path = self.initModelPath('seg')
  149. if not self.cli_args.classification_path:
  150. self.cli_args.classification_path = self.initModelPath('cls')
  151. self.seg_model, self.cls_model, self.malignancy_model = self.initModels()
  152. def initModelPath(self, type_str):
  153. local_path = os.path.join(
  154. 'data-unversioned',
  155. 'part2',
  156. 'models',
  157. 'p2ch13',#self.cli_args.tb_prefix,
  158. type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
  159. )
  160. file_list = glob.glob(local_path)
  161. if not file_list:
  162. pretrained_path = os.path.join(
  163. 'data',
  164. 'part2',
  165. 'models',
  166. type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
  167. )
  168. file_list = glob.glob(pretrained_path)
  169. else:
  170. pretrained_path = None
  171. file_list.sort()
  172. try:
  173. return file_list[-1]
  174. except IndexError:
  175. log.debug([local_path, pretrained_path, file_list])
  176. raise
  177. def initModels(self):
  178. log.debug(self.cli_args.segmentation_path)
  179. seg_dict = torch.load(self.cli_args.segmentation_path)
  180. seg_model = UNetWrapper(in_channels=7, n_classes=1, depth=3, wf=4, padding=True, batch_norm=True, up_mode='upconv')
  181. seg_model.load_state_dict(seg_dict['model_state'])
  182. seg_model.eval()
  183. log.debug(self.cli_args.classification_path)
  184. cls_dict = torch.load(self.cli_args.classification_path)
  185. model_cls = getattr(p2ch14.model, self.cli_args.cls_model)
  186. cls_model = model_cls()
  187. cls_model.load_state_dict(cls_dict['model_state'])
  188. cls_model.eval()
  189. if self.use_cuda:
  190. if torch.cuda.device_count() > 1:
  191. seg_model = nn.DataParallel(seg_model)
  192. cls_model = nn.DataParallel(cls_model)
  193. seg_model.to(self.device)
  194. cls_model.to(self.device)
  195. if self.cli_args.malignancy_path:
  196. model_cls = getattr(p2ch14.model, self.cli_args.malignancy_model)
  197. malignancy_model = model_cls()
  198. malignancy_dict = torch.load(self.cli_args.malignancy_path)
  199. malignancy_model.load_state_dict(malignancy_dict['model_state'])
  200. malignancy_model.eval()
  201. if self.use_cuda:
  202. malignancy_model.to(self.device)
  203. else:
  204. malignancy_model = None
  205. return seg_model, cls_model, malignancy_model
  206. def initSegmentationDl(self, series_uid):
  207. seg_ds = Luna2dSegmentationDataset(
  208. contextSlices_count=3,
  209. series_uid=series_uid,
  210. fullCt_bool=True,
  211. )
  212. seg_dl = DataLoader(
  213. seg_ds,
  214. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  215. num_workers=self.cli_args.num_workers,
  216. pin_memory=self.use_cuda,
  217. )
  218. return seg_dl
  219. def initClassificationDl(self, candidateInfo_list):
  220. cls_ds = LunaDataset(
  221. sortby_str='series_uid',
  222. candidateInfo_list=candidateInfo_list,
  223. )
  224. cls_dl = DataLoader(
  225. cls_ds,
  226. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  227. num_workers=self.cli_args.num_workers,
  228. pin_memory=self.use_cuda,
  229. )
  230. return cls_dl
  231. def main(self):
  232. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  233. val_ds = LunaDataset(
  234. val_stride=10,
  235. isValSet_bool=True,
  236. )
  237. val_set = set(
  238. candidateInfo_tup.series_uid
  239. for candidateInfo_tup in val_ds.candidateInfo_list
  240. )
  241. positive_set = set(
  242. candidateInfo_tup.series_uid
  243. for candidateInfo_tup in getCandidateInfoList()
  244. if candidateInfo_tup.isNodule_bool
  245. )
  246. if self.cli_args.series_uid:
  247. series_set = set(self.cli_args.series_uid.split(','))
  248. else:
  249. series_set = set(
  250. candidateInfo_tup.series_uid
  251. for candidateInfo_tup in getCandidateInfoList()
  252. )
  253. train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
  254. val_list = sorted(series_set & val_set)
  255. candidateInfo_dict = getCandidateInfoDict()
  256. series_iter = enumerateWithEstimate(
  257. val_list + train_list,
  258. "Series",
  259. )
  260. all_confusion = np.zeros((3, 4), dtype=np.int)
  261. for _, series_uid in series_iter:
  262. ct = getCt(series_uid)
  263. mask_a = self.segmentCt(ct, series_uid)
  264. candidateInfo_list = self.groupSegmentationOutput(
  265. series_uid, ct, mask_a
  266. )
  267. classifications_list = self.classifyCandidates(ct, candidateInfo_list)
  268. if not self.cli_args.run_validation:
  269. print(f"found nodule candidates in {series_uid}:")
  270. for prob, prob_mal, center_xyz, center_irc in classifications_list:
  271. if prob > 0.5:
  272. s = f"nodule prob {prob:.3f}, "
  273. if self.malignancy_model:
  274. s += f"malignancy prob {prob_mal:.3f}, "
  275. s += f"center xyz {center_xyz}"
  276. print(s)
  277. if series_uid in candidateInfo_dict:
  278. one_confusion = match_and_score(classifications_list, candidateInfo_dict[series_uid])
  279. all_confusion += one_confusion
  280. print_confusion(series_uid, one_confusion, self.malignancy_model is not None)
  281. print_confusion("Total", all_confusion, self.malignancy_model is not None)
  282. def classifyCandidates(self, ct, candidateInfo_list):
  283. cls_dl = self.initClassificationDl(candidateInfo_list)
  284. classifications_list = []
  285. for batch_ndx, batch_tup in enumerate(cls_dl):
  286. input_t, _, _, series_list, center_list = batch_tup
  287. input_g = input_t.to(self.device)
  288. with torch.no_grad():
  289. _, probability_nodule_g = self.cls_model(input_g)
  290. if self.malignancy_model is not None:
  291. _, probability_mal_g = self.malignancy_model(input_g)
  292. else:
  293. probability_mal_g = torch.zeros_like(probability_nodule_g)
  294. zip_iter = zip(
  295. center_list,
  296. probability_nodule_g[:,1].tolist(),
  297. probability_mal_g[:,1].tolist(),
  298. )
  299. for center_irc, prob_nodule, prob_mal in zip_iter:
  300. center_xyz = irc2xyz(
  301. center_irc,
  302. direction_a=ct.direction_a,
  303. origin_xyz=ct.origin_xyz,
  304. vxSize_xyz=ct.vxSize_xyz,
  305. )
  306. cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
  307. classifications_list.append(cls_tup)
  308. return classifications_list
  309. def segmentCt(self, ct, series_uid):
  310. with torch.no_grad():
  311. output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
  312. seg_dl = self.initSegmentationDl(series_uid)
  313. for batch_tup in seg_dl:
  314. input_t, label_t, series_list, slice_ndx_list = batch_tup
  315. input_g = input_t.to(self.device)
  316. prediction_g = self.seg_model(input_g)
  317. for i, slice_ndx in enumerate(slice_ndx_list):
  318. output_a[slice_ndx] = prediction_g[i].cpu().numpy()
  319. mask_a = output_a > 0.5
  320. mask_a = morphology.binary_erosion(mask_a, iterations=1)
  321. return mask_a
  322. def groupSegmentationOutput(self, series_uid, ct, clean_a):
  323. candidateLabel_a, candidate_count = measurements.label(clean_a)
  324. centerIrc_list = measurements.center_of_mass(
  325. ct.hu_a.clip(-1000, 1000) + 1001,
  326. labels=candidateLabel_a,
  327. index=np.arange(1, candidate_count+1),
  328. )
  329. candidateInfo_list = []
  330. for i, center_irc in enumerate(centerIrc_list):
  331. center_xyz = irc2xyz(
  332. center_irc,
  333. ct.origin_xyz,
  334. ct.vxSize_xyz,
  335. ct.direction_a,
  336. )
  337. assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, candidate_count])
  338. assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
  339. candidateInfo_tup = \
  340. CandidateInfoTuple(False, False, False, 0.0, series_uid, center_xyz)
  341. candidateInfo_list.append(candidateInfo_tup)
  342. return candidateInfo_list
  343. def logResults(self, mode_str, filtered_list, series2diagnosis_dict, positive_set):
  344. count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
  345. for series_uid in filtered_list:
  346. probablity_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
  347. if center_irc is not None:
  348. center_irc = tuple(int(x.item()) for x in center_irc)
  349. positive_bool = series_uid in positive_set
  350. prediction_bool = probablity_float > 0.5
  351. correct_bool = positive_bool == prediction_bool
  352. if positive_bool and prediction_bool:
  353. count_dict['tp'] += 1
  354. if not positive_bool and not prediction_bool:
  355. count_dict['tn'] += 1
  356. if not positive_bool and prediction_bool:
  357. count_dict['fp'] += 1
  358. if positive_bool and not prediction_bool:
  359. count_dict['fn'] += 1
  360. log.info("{} {} Label:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
  361. mode_str,
  362. series_uid,
  363. positive_bool,
  364. prediction_bool,
  365. correct_bool,
  366. probablity_float,
  367. center_irc,
  368. ))
  369. total_count = sum(count_dict.values())
  370. percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
  371. precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
  372. recall = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
  373. percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)
  374. log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
  375. **percent_dict,
  376. ))
  377. log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
  378. **percent_dict,
  379. ))
  380. if __name__ == '__main__':
  381. NoduleAnalysisApp().main()