check_nodule_fp_rate.py
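# What this script does (summarized from the code below): it runs the saved
# segmentation model over whole CT series, clusters the cleaned-up output into
# candidate locations, runs the saved classification model over those
# candidates, and then matches the results against the annotated nodules so
# that missed positives and per-series tp/fn/fp/tn counts can be logged.
#
# Example invocation (assuming this file lives in the p2ch14 package of the
# project; adjust the module path to match your checkout):
#
#   python -m p2ch14.check_nodule_fp_rate
#   python -m p2ch14.check_nodule_fp_rate --series-uid <series UID> --batch-size 4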

import argparse
import glob
import hashlib
import math
import os
import sys

import numpy as np
import scipy.ndimage.measurements as measure
import scipy.ndimage.morphology as morph

import torch
import torch.nn as nn
import torch.optim

from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
# from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getCandidateInfoList, CandidateInfoTuple
from p2ch13.dsets import Luna2dSegmentationDataset, getCt, getCandidateInfoList, getCandidateInfoDict, CandidateInfoTuple
from p2ch14.dsets import LunaDataset
from p2ch13.model import UNetWrapper
from p2ch14.model import LunaModel

from util.logconf import logging
from util.util import xyz2irc, irc2xyz

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class FalsePosRateCheckApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            log.debug(sys.argv)
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--series-uid',
            help='Limit inference to this Series UID only.',
            default=None,
            type=str,
        )
        parser.add_argument('--include-train',
            help="Include data that was in the training set. (default: validation data only)",
            action='store_true',
            default=False,
        )
        parser.add_argument('--segmentation-path',
            help="Path to the saved segmentation model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--classification-path',
            help="Path to the saved classification model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch13',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )

        self.cli_args = parser.parse_args(sys_argv)
        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        if not self.cli_args.segmentation_path:
            self.cli_args.segmentation_path = self.initModelPath('seg')

        if not self.cli_args.classification_path:
            self.cli_args.classification_path = self.initModelPath('cls')

        self.seg_model, self.cls_model = self.initModels()

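    # If --segmentation-path / --classification-path are not given on the command
    # line, fall back to globbing data/part2/models/ for a saved .state file of
    # the matching kind ('seg' or 'cls'). Sorting the glob results and taking the
    # last entry assumes the file names sort chronologically (the training
    # scripts embed a timestamp in them).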
    def initModelPath(self, type_str):
        # local_path = os.path.join(
        #     'data-unversioned',
        #     'part2',
        #     'models',
        #     self.cli_args.tb_prefix,
        #     type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
        # )
        #
        # file_list = glob.glob(local_path)
        # if not file_list:
        pretrained_path = os.path.join(
            'data',
            'part2',
            'models',
            type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
        )
        file_list = glob.glob(pretrained_path)
        # else:
        #     pretrained_path = None

        file_list.sort()

        try:
            return file_list[-1]
        except IndexError:
            log.debug([pretrained_path, file_list])
            raise

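    # Both models are restored from saved state dicts; the SHA1 of each weights
    # file is logged so a run can be tied back to the exact files used. The
    # circular-kernel convolutions built at the end back the erode() helper that
    # cleans up the segmentation mask.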
    def initModels(self):
        with open(self.cli_args.segmentation_path, 'rb') as f:
            log.debug(self.cli_args.segmentation_path)
            log.debug(hashlib.sha1(f.read()).hexdigest())

        seg_dict = torch.load(self.cli_args.segmentation_path)

        seg_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )
        seg_model.load_state_dict(seg_dict['model_state'])
        seg_model.eval()

        with open(self.cli_args.classification_path, 'rb') as f:
            log.debug(self.cli_args.classification_path)
            log.debug(hashlib.sha1(f.read()).hexdigest())

        cls_dict = torch.load(self.cli_args.classification_path)

        cls_model = LunaModel()
        # cls_model = AlternateLunaModel()
        cls_model.load_state_dict(cls_dict['model_state'])
        cls_model.eval()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                seg_model = nn.DataParallel(seg_model)
                cls_model = nn.DataParallel(cls_model)
            seg_model = seg_model.to(self.device)
            cls_model = cls_model.to(self.device)

        self.conv_list = nn.ModuleList([
            self._make_circle_conv(radius).to(self.device) for radius in range(1, 8)
        ])

        return seg_model, cls_model

    def initSegmentationDl(self, series_uid):
        seg_ds = Luna2dSegmentationDataset(
            contextSlices_count=3,
            series_uid=series_uid,
            fullCt_bool=True,
        )
        seg_dl = DataLoader(
            seg_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=1,  # self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return seg_dl

    def initClassificationDl(self, candidateInfo_list):
        cls_ds = LunaDataset(
            sortby_str='series_uid',
            candidateInfo_list=candidateInfo_list,
        )
        cls_dl = DataLoader(
            cls_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=1,  # self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return cls_dl

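    # Per-series flow in main(): segment the full CT, cluster the eroded mask
    # into candidates, classify each candidate, then match the classified
    # candidates against the annotated nodules for that series.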
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = set(
            candidateInfo_tup.series_uid
            for candidateInfo_tup in val_ds.candidateInfo_list
        )
        positive_set = set(
            candidateInfo_tup.series_uid
            for candidateInfo_tup in getCandidateInfoList()
            if candidateInfo_tup.isNodule_bool
        )

        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(
                candidateInfo_tup.series_uid
                for candidateInfo_tup in getCandidateInfoList()
            )

        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
        val_list = sorted(series_set & val_set)

        total_tp = total_tn = total_fp = total_fn = 0
        total_missed_pos = 0
        missed_pos_dist_list = []
        missed_pos_cit_list = []

        candidateInfo_dict = getCandidateInfoDict()
        # series2results_dict = {}
        # seg_candidateInfo_list = []
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        for _series_ndx, series_uid in series_iter:
            ct, _output_g, _mask_g, clean_g = self.segmentCt(series_uid)

            seg_candidateInfo_list, _seg_centerIrc_list, _ = self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_g,
            )
            if not seg_candidateInfo_list:
                continue

            cls_dl = self.initClassificationDl(seg_candidateInfo_list)
            results_list = []
            # batch_iter = enumerateWithEstimate(
            #     cls_dl,
            #     "Cls all",
            #     start_ndx=cls_dl.num_workers,
            # )
            # for batch_ndx, batch_tup in batch_iter:
            for batch_ndx, batch_tup in enumerate(cls_dl):
                input_t, label_t, index_t, series_list, center_t = batch_tup

                input_g = input_t.to(self.device)
                with torch.no_grad():
                    _logits_g, probability_g = self.cls_model(input_g)
                probability_t = probability_g.to('cpu')
                # probability_t = torch.tensor([[0, 1]] * input_t.shape[0], dtype=torch.float32)

                for i, _series_uid in enumerate(series_list):
                    assert series_uid == _series_uid, repr([batch_ndx, i, series_uid, _series_uid, seg_candidateInfo_list])
                    results_list.append((center_t[i], probability_t[i,0].item()))

            # This part is all about matching up annotations with our segmentation results
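            # Each annotated nodule is matched to the nearest candidate center,
            # with distances measured in patient-space millimeters. A match
            # within max(10 mm, nodule radius) counts as found; otherwise the
            # nodule is logged as a missed positive, i.e. segmentation never
            # produced a candidate anywhere near it.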
            tp = tn = fp = fn = 0
            missed_pos = 0
            ct = getCt(series_uid)
            candidateInfo_list = candidateInfo_dict[series_uid]
            candidateInfo_list = [cit for cit in candidateInfo_list if cit.isNodule_bool]

            found_cit_list = [None] * len(results_list)

            for candidateInfo_tup in candidateInfo_list:
                min_dist = (999, None)

                for result_ndx, (result_center_irc_t, nodule_probability_t) in enumerate(results_list):
                    result_center_xyz = irc2xyz(result_center_irc_t, ct.origin_xyz, ct.vxSize_xyz, ct.direction_a)
                    delta_xyz_t = torch.tensor(result_center_xyz) - torch.tensor(candidateInfo_tup.center_xyz)
                    distance_t = (delta_xyz_t ** 2).sum().sqrt()

                    min_dist = min(min_dist, (distance_t, result_ndx))

                distance_cutoff = max(10, candidateInfo_tup.diameter_mm / 2)
                if min_dist[0] < distance_cutoff:
                    found_dist, result_ndx = min_dist
                    nodule_probability_t = results_list[result_ndx][1]

                    assert candidateInfo_tup.isNodule_bool

                    if nodule_probability_t > 0.5:
                        tp += 1
                    else:
                        fn += 1

                    found_cit_list[result_ndx] = candidateInfo_tup
                else:
                    log.warning("!!! Missed positive {}; {} min dist !!!".format(candidateInfo_tup, min_dist))
                    missed_pos += 1
                    missed_pos_dist_list.append(float(min_dist[0]))
                    missed_pos_cit_list.append(candidateInfo_tup)

            # # TODO remove
            # acceptable_set = {
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.102681962408431413578140925249',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.195557219224169985110295082004',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.216252660192313507027754194207',
            #     # '1.3.6.1.4.1.14519.5.2.1.6279.6001.229096941293122177107846044795',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.229096941293122177107846044795',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.299806338046301317870803017534',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.395623571499047043765181005112',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.487745546557477250336016826588',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.970428941353693253759289796610',
            # }
            # if missed_pos > 0 and series_uid not in acceptable_set:
            #     log.info("Unacceptable series_uid: " + series_uid)
            #     break
            #
            # if total_missed_pos > 10:
            #     break
            #
            #
            # for result_ndx, (result_center_irc_t, nodule_probability_t) in enumerate(results_list):
            #     if found_cit_list[result_ndx] is None:
            #         if nodule_probability_t > 0.5:
            #             fp += 1
            #         else:
            #             tn += 1

            log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format(series_uid, missed_pos, fn, fp, tp, tn))
            total_tp += tp
            total_tn += tn
            total_fp += fp
            total_fn += fn
            total_missed_pos += missed_pos

        with open(self.cli_args.segmentation_path, 'rb') as f:
            log.info(self.cli_args.segmentation_path)
            log.info(hashlib.sha1(f.read()).hexdigest())
        with open(self.cli_args.classification_path, 'rb') as f:
            log.info(self.cli_args.classification_path)
            log.info(hashlib.sha1(f.read()).hexdigest())

        log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format('total', total_missed_pos, total_fn, total_fp, total_tp, total_tn))

        # missed_pos_dist_list.sort()
        # log.info("missed_pos_dist_list {}".format(missed_pos_dist_list))
        for cit, dist in zip(missed_pos_cit_list, missed_pos_dist_list):
            log.info("    Missed by {}: {}".format(dist, cit))

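    # Run the segmentation model over every slice of the CT and write the
    # per-slice predictions back into a full-volume tensor, then threshold at
    # 0.5 and erode the resulting mask with a radius-1 circular kernel.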
    def segmentCt(self, series_uid):
        with torch.no_grad():
            ct = getCt(series_uid)

            output_g = torch.zeros(ct.hu_a.shape, dtype=torch.float32, device=self.device)

            seg_dl = self.initSegmentationDl(series_uid)
            for batch_tup in seg_dl:
                input_t, label_t, series_list, slice_ndx_list = batch_tup

                input_g = input_t.to(self.device)
                prediction_g = self.seg_model(input_g)

                for i, slice_ndx in enumerate(slice_ndx_list):
                    output_g[slice_ndx] = prediction_g[i,0]

            mask_g = output_g > 0.5
            clean_g = self.erode(mask_g.unsqueeze(0).unsqueeze(0), 1)[0][0]
            # mask_a = output_a > 0.5
            # clean_a = morph.binary_erosion(mask_a, iterations=1)
            # clean_a = morph.binary_dilation(clean_a, iterations=2)

        return ct, output_g, mask_g, clean_g

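    # Erosion via convolution: the kernel below is a flat disc of the given
    # radius, normalized to sum to 1 and applied per slice. erode() then keeps
    # only the voxels where the averaged value reaches the threshold of 1,
    # i.e. where the disc fits entirely inside the mask, which mimics a
    # morphological erosion while staying on the GPU.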
    def _make_circle_conv(self, radius):
        diameter = 1 + radius * 2

        a = torch.linspace(-1, 1, steps=diameter)**2
        b = (a[None] + a[:, None])**0.5

        circle_weights = (b <= 1.0).to(torch.float32)

        conv = nn.Conv3d(1, 1, kernel_size=(1, diameter, diameter), padding=(0, radius, radius), bias=False)
        conv.weight.data.fill_(1)
        conv.weight.data *= circle_weights / circle_weights.sum()

        return conv

    def erode(self, input_mask, radius, threshold=1):
        conv = self.conv_list[radius - 1]
        input_float = input_mask.to(torch.float32)
        result = conv(input_float)

        # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
        # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])

        return result >= threshold

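    # Connected-component labeling (scipy.ndimage) groups the eroded mask into
    # candidate blobs. Each blob's center of mass is computed over the HU values
    # clipped to [-1000, 1000] and shifted by +1001 so every voxel weight is
    # positive, then converted from voxel (IRC) to patient (XYZ) coordinates to
    # build the candidate list handed to the classifier.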
    def clusterSegmentationOutput(self, series_uid, ct, clean_g):
        clean_a = clean_g.cpu().numpy()
        candidateLabel_a, candidate_count = measure.label(clean_a)
        centerIrc_list = measure.center_of_mass(
            ct.hu_a.clip(-1000, 1000) + 1001,
            labels=candidateLabel_a,
            index=list(range(1, candidate_count+1)),
        )

        candidateInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            assert np.isfinite(center_irc).all(), repr([series_uid, i, candidate_count, (ct.hu_a[candidateLabel_a == i+1]).sum(), center_irc])
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_a,
            )
            diameter_mm = 0.0
            # pixel_count = (candidateLabel_a == i+1).sum()
            # area_mm2 = pixel_count * ct.vxSize_xyz[0] * ct.vxSize_xyz[1]
            # diameter_mm = 2 * (area_mm2 / math.pi) ** 0.5

            candidateInfo_tup = \
                CandidateInfoTuple(None, None, None, diameter_mm, series_uid, center_xyz)
            candidateInfo_list.append(candidateInfo_tup)

        return candidateInfo_list, centerIrc_list, candidateLabel_a

    # def logResults(self, mode_str, filtered_list, series2diagnosis_dict, positive_set):
    #     count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
    #     for series_uid in filtered_list:
    #         probablity_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
    #         if center_irc is not None:
    #             center_irc = tuple(int(x.item()) for x in center_irc)
    #         positive_bool = series_uid in positive_set
    #         prediction_bool = probablity_float > 0.5
    #         correct_bool = positive_bool == prediction_bool
    #
    #         if positive_bool and prediction_bool:
    #             count_dict['tp'] += 1
    #         if not positive_bool and not prediction_bool:
    #             count_dict['tn'] += 1
    #         if not positive_bool and prediction_bool:
    #             count_dict['fp'] += 1
    #         if positive_bool and not prediction_bool:
    #             count_dict['fn'] += 1
    #
    #
    #         log.info("{} {} Label:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
    #             mode_str,
    #             series_uid,
    #             positive_bool,
    #             prediction_bool,
    #             correct_bool,
    #             probablity_float,
    #             center_irc,
    #         ))
    #
    #     total_count = sum(count_dict.values())
    #     percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
    #
    #     precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
    #     recall = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
    #     percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)
    #
    #     log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
    #         **percent_dict,
    #     ))
    #     log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
    #         **percent_dict,
    #     ))


if __name__ == '__main__':
    FalsePosRateCheckApp().main()