import argparse
import glob
import hashlib
import math
import os
import sys

import numpy as np
import scipy.ndimage.measurements as measure
import scipy.ndimage.morphology as morph

import torch
import torch.nn as nn
import torch.optim

from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
# from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getCandidateInfoList, CandidateInfoTuple
from p2ch13.dsets import Luna2dSegmentationDataset, getCt, getCandidateInfoList, getCandidateInfoDict, CandidateInfoTuple
from p2ch14.dsets import LunaDataset
from p2ch13.model import UNetWrapper
from p2ch14.model import LunaModel

from util.logconf import logging
from util.util import xyz2irc, irc2xyz

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
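

# Harness for measuring the end-to-end false positive rate of the two-stage
# pipeline: the chapter 13 segmentation model (UNetWrapper) proposes candidate
# locations in whole CT scans, and the chapter 14 classification model
# (LunaModel) decides which candidates are nodules. Results are compared
# against the human annotations for each series.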
class FalsePosRateCheckApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            log.debug(sys.argv)
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for inference',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--series-uid',
            help='Limit inference to this Series UID only.',
            default=None,
            type=str,
        )
        parser.add_argument('--include-train',
            help="Include data that was in the training set. (default: validation data only)",
            action='store_true',
            default=False,
        )
        parser.add_argument('--segmentation-path',
            help="Path to the saved segmentation model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--classification-path',
            help="Path to the saved classification model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch13',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )

        self.cli_args = parser.parse_args(sys_argv)
        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        if not self.cli_args.segmentation_path:
            self.cli_args.segmentation_path = self.initModelPath('seg')

        if not self.cli_args.classification_path:
            self.cli_args.classification_path = self.initModelPath('cls')

        self.seg_model, self.cls_model = self.initModels()

    def initModelPath(self, type_str):
        # local_path = os.path.join(
        #     'data-unversioned',
        #     'part2',
        #     'models',
        #     self.cli_args.tb_prefix,
        #     type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
        # )
        #
        # file_list = glob.glob(local_path)
        # if not file_list:
        pretrained_path = os.path.join(
            'data',
            'part2',
            'models',
            type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
        )
        file_list = glob.glob(pretrained_path)
        # else:
        #     pretrained_path = None

        file_list.sort()

        try:
            return file_list[-1]
        except IndexError:
            log.debug([pretrained_path, file_list])
            raise
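
    # Load both saved model states, logging a SHA1 digest of each state file
    # so a run can be tied back to the exact weights that produced it. Also
    # builds the circle-shaped convolutions used for mask cleanup in erode().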
    def initModels(self):
        with open(self.cli_args.segmentation_path, 'rb') as f:
            log.debug(self.cli_args.segmentation_path)
            log.debug(hashlib.sha1(f.read()).hexdigest())

        seg_dict = torch.load(self.cli_args.segmentation_path)

        seg_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )

        seg_model.load_state_dict(seg_dict['model_state'])
        seg_model.eval()

        with open(self.cli_args.classification_path, 'rb') as f:
            log.debug(self.cli_args.classification_path)
            log.debug(hashlib.sha1(f.read()).hexdigest())

        cls_dict = torch.load(self.cli_args.classification_path)

        cls_model = LunaModel()
        # cls_model = AlternateLunaModel()

        cls_model.load_state_dict(cls_dict['model_state'])
        cls_model.eval()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                seg_model = nn.DataParallel(seg_model)
                cls_model = nn.DataParallel(cls_model)

            seg_model = seg_model.to(self.device)
            cls_model = cls_model.to(self.device)

        self.conv_list = nn.ModuleList([
            self._make_circle_conv(radius).to(self.device) for radius in range(1, 8)
        ])

        return seg_model, cls_model
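
    # Per-series segmentation loader; fullCt_bool=True asks the dataset for
    # every slice of the CT, not just the slices known to contain nodules.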
    def initSegmentationDl(self, series_uid):
        seg_ds = Luna2dSegmentationDataset(
            contextSlices_count=3,
            series_uid=series_uid,
            fullCt_bool=True,
        )
        seg_dl = DataLoader(
            seg_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=1,  # self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return seg_dl

    def initClassificationDl(self, candidateInfo_list):
        cls_ds = LunaDataset(
            sortby_str='series_uid',
            candidateInfo_list=candidateInfo_list,
        )
        cls_dl = DataLoader(
            cls_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=1,  # self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return cls_dl
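
    # Per-series flow: segment the full CT, cluster the segmentation output
    # into candidates, classify each candidate, then match classifier output
    # against the annotated nodule locations to tally per-series and total
    # true/false positive and negative counts.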
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = set(
            candidateInfo_tup.series_uid
            for candidateInfo_tup in val_ds.candidateInfo_list
        )
        positive_set = set(
            candidateInfo_tup.series_uid
            for candidateInfo_tup in getCandidateInfoList()
            if candidateInfo_tup.isNodule_bool
        )

        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(
                candidateInfo_tup.series_uid
                for candidateInfo_tup in getCandidateInfoList()
            )

        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
        val_list = sorted(series_set & val_set)

        total_tp = total_tn = total_fp = total_fn = 0
        total_missed_pos = 0
        missed_pos_dist_list = []
        missed_pos_cit_list = []
        candidateInfo_dict = getCandidateInfoDict()
        # series2results_dict = {}
        # seg_candidateInfo_list = []

        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        for _series_ndx, series_uid in series_iter:
            ct, _output_g, _mask_g, clean_g = self.segmentCt(series_uid)
            seg_candidateInfo_list, _seg_centerIrc_list, _ = self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_g,
            )
            if not seg_candidateInfo_list:
                continue

            cls_dl = self.initClassificationDl(seg_candidateInfo_list)
            results_list = []
            # batch_iter = enumerateWithEstimate(
            #     cls_dl,
            #     "Cls all",
            #     start_ndx=cls_dl.num_workers,
            # )
            # for batch_ndx, batch_tup in batch_iter:
            for batch_ndx, batch_tup in enumerate(cls_dl):
                input_t, label_t, index_t, series_list, center_t = batch_tup

                input_g = input_t.to(self.device)
                with torch.no_grad():
                    _logits_g, probability_g = self.cls_model(input_g)
                probability_t = probability_g.to('cpu')
                # probability_t = torch.tensor([[0, 1]] * input_t.shape[0], dtype=torch.float32)

                for i, _series_uid in enumerate(series_list):
                    assert series_uid == _series_uid, repr([batch_ndx, i, series_uid, _series_uid, seg_candidateInfo_list])
                    results_list.append((center_t[i], probability_t[i,0].item()))

            # This part is all about matching up annotations with our segmentation results
            tp = tn = fp = fn = 0
            missed_pos = 0
            ct = getCt(series_uid)
            candidateInfo_list = candidateInfo_dict[series_uid]
            candidateInfo_list = [cit for cit in candidateInfo_list if cit.isNodule_bool]
            found_cit_list = [None] * len(results_list)

            for candidateInfo_tup in candidateInfo_list:
                min_dist = (999, None)

                for result_ndx, (result_center_irc_t, nodule_probability_t) in enumerate(results_list):
                    result_center_xyz = irc2xyz(result_center_irc_t, ct.origin_xyz, ct.vxSize_xyz, ct.direction_a)
                    delta_xyz_t = torch.tensor(result_center_xyz) - torch.tensor(candidateInfo_tup.center_xyz)
                    distance_t = (delta_xyz_t ** 2).sum().sqrt()

                    min_dist = min(min_dist, (distance_t, result_ndx))

                distance_cutoff = max(10, candidateInfo_tup.diameter_mm / 2)
                if min_dist[0] < distance_cutoff:
                    found_dist, result_ndx = min_dist
                    nodule_probability_t = results_list[result_ndx][1]

                    assert candidateInfo_tup.isNodule_bool

                    if nodule_probability_t > 0.5:
                        tp += 1
                    else:
                        fn += 1

                    found_cit_list[result_ndx] = candidateInfo_tup
                else:
                    log.warning("!!! Missed positive {}; {} min dist !!!".format(candidateInfo_tup, min_dist))
                    missed_pos += 1
                    missed_pos_dist_list.append(float(min_dist[0]))
                    missed_pos_cit_list.append(candidateInfo_tup)

            # # TODO remove
            # acceptable_set = {
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.102681962408431413578140925249',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.195557219224169985110295082004',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.216252660192313507027754194207',
            #     # '1.3.6.1.4.1.14519.5.2.1.6279.6001.229096941293122177107846044795',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.229096941293122177107846044795',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.299806338046301317870803017534',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.395623571499047043765181005112',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.487745546557477250336016826588',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.970428941353693253759289796610',
            # }
            # if missed_pos > 0 and series_uid not in acceptable_set:
            #     log.info("Unacceptable series_uid: " + series_uid)
            #     break
            #
            # if total_missed_pos > 10:
            #     break
            #
            #
            # for result_ndx, (result_center_irc_t, nodule_probability_t) in enumerate(results_list):
            #     if found_cit_list[result_ndx] is None:
            #         if nodule_probability_t > 0.5:
            #             fp += 1
            #         else:
            #             tn += 1
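
            # Note: with the block above commented out, fp and tn stay at
            # zero; only missed_pos, tp, and fn are meaningful below.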
- log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format(series_uid, missed_pos, fn, fp, tp, tn))
- total_tp += tp
- total_tn += tn
- total_fp += fp
- total_fn += fn
- total_missed_pos += missed_pos
- with open(self.cli_args.segmentation_path, 'rb') as f:
- log.info(self.cli_args.segmentation_path)
- log.info(hashlib.sha1(f.read()).hexdigest())
- with open(self.cli_args.classification_path, 'rb') as f:
- log.info(self.cli_args.classification_path)
- log.info(hashlib.sha1(f.read()).hexdigest())
- log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format('total', total_missed_pos, total_fn, total_fp, total_tp, total_tn))
- # missed_pos_dist_list.sort()
- # log.info("missed_pos_dist_list {}".format(missed_pos_dist_list))
- for cit, dist in zip(missed_pos_cit_list, missed_pos_dist_list):
- log.info(" Missed by {}: {}".format(dist, cit))
    def segmentCt(self, series_uid):
        with torch.no_grad():
            ct = getCt(series_uid)

            output_g = torch.zeros(ct.hu_a.shape, dtype=torch.float32, device=self.device)
            seg_dl = self.initSegmentationDl(series_uid)
            for batch_tup in seg_dl:
                input_t, label_t, series_list, slice_ndx_list = batch_tup

                input_g = input_t.to(self.device)
                prediction_g = self.seg_model(input_g)

                for i, slice_ndx in enumerate(slice_ndx_list):
                    output_g[slice_ndx] = prediction_g[i,0]

            mask_g = output_g > 0.5
            clean_g = self.erode(mask_g.unsqueeze(0).unsqueeze(0), 1)[0][0]
            # mask_a = output_a > 0.5
            # clean_a = morph.binary_erosion(mask_a, iterations=1)
            # clean_a = morph.binary_dilation(clean_a, iterations=2)

        return ct, output_g, mask_g, clean_g
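
    # Builds a 3D convolution whose kernel is a normalized disk in the
    # row/column plane (one voxel thick along the index axis). The weights
    # sum to 1, so the output is the mean of the input over the disk.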
    def _make_circle_conv(self, radius):
        diameter = 1 + radius * 2

        a = torch.linspace(-1, 1, steps=diameter)**2
        b = (a[None] + a[:, None])**0.5

        circle_weights = (b <= 1.0).to(torch.float32)

        conv = nn.Conv3d(1, 1, kernel_size=(1, diameter, diameter), padding=(0, radius, radius), bias=False)
        conv.weight.data.fill_(1)
        conv.weight.data *= circle_weights / circle_weights.sum()

        return conv
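
    # Erosion expressed as convolution: the disk kernel averages the mask
    # over a circular neighborhood, so the mean reaches 1.0 only where every
    # voxel under the disk is set; thresholding at 1 keeps just those voxels.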
    def erode(self, input_mask, radius, threshold=1):
        conv = self.conv_list[radius - 1]
        input_float = input_mask.to(torch.float32)
        result = conv(input_float)

        # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
        # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])

        return result >= threshold
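
    # Turns the cleaned segmentation mask into classifier input: label
    # connected components, take each component's center of mass (weighted
    # by CT density, clipped and shifted so all weights are positive), and
    # wrap each center as a CandidateInfoTuple.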
    def clusterSegmentationOutput(self, series_uid, ct, clean_g):
        clean_a = clean_g.cpu().numpy()
        candidateLabel_a, candidate_count = measure.label(clean_a)
        centerIrc_list = measure.center_of_mass(
            ct.hu_a.clip(-1000, 1000) + 1001,
            labels=candidateLabel_a,
            index=list(range(1, candidate_count+1)),
        )

        candidateInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            assert np.isfinite(center_irc).all(), repr([series_uid, i, candidate_count, (ct.hu_a[candidateLabel_a == i+1]).sum(), center_irc])
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_a,
            )
            diameter_mm = 0.0
            # pixel_count = (candidateLabel_a == i+1).sum()
            # area_mm2 = pixel_count * ct.vxSize_xyz[0] * ct.vxSize_xyz[1]
            # diameter_mm = 2 * (area_mm2 / math.pi) ** 0.5

            candidateInfo_tup = \
                CandidateInfoTuple(None, None, None, diameter_mm, series_uid, center_xyz)
            candidateInfo_list.append(candidateInfo_tup)

        return candidateInfo_list, centerIrc_list, candidateLabel_a
    # def logResults(self, mode_str, filtered_list, series2diagnosis_dict, positive_set):
    #     count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
    #     for series_uid in filtered_list:
    #         probability_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
    #         if center_irc is not None:
    #             center_irc = tuple(int(x.item()) for x in center_irc)
    #         positive_bool = series_uid in positive_set
    #         prediction_bool = probability_float > 0.5
    #         correct_bool = positive_bool == prediction_bool
    #
    #         if positive_bool and prediction_bool:
    #             count_dict['tp'] += 1
    #         if not positive_bool and not prediction_bool:
    #             count_dict['tn'] += 1
    #         if not positive_bool and prediction_bool:
    #             count_dict['fp'] += 1
    #         if positive_bool and not prediction_bool:
    #             count_dict['fn'] += 1
    #
    #         log.info("{} {} Label:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
    #             mode_str,
    #             series_uid,
    #             positive_bool,
    #             prediction_bool,
    #             correct_bool,
    #             probability_float,
    #             center_irc,
    #         ))
    #
    #     total_count = sum(count_dict.values())
    #     percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
    #
    #     precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
    #     recall = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
    #     percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)
    #
    #     log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
    #         **percent_dict,
    #     ))
    #     log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
    #         **percent_dict,
    #     ))

if __name__ == '__main__':
    FalsePosRateCheckApp().main()
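
# Example invocation (illustrative; assumes this module lives at
# p2ch14/check_nodule_fp_rate.py, so adjust the module path to match the
# actual filename). By default the newest seg_* and cls_* .state files under
# data/part2/models/ are picked up automatically:
#
#   python -m p2ch14.check_nodule_fp_rate \
#       --batch-size 4 --num-workers 8 \
#       --series-uid 1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860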