|
|
@@ -31,16 +31,15 @@ logging.getLogger("p2ch13.dsets").setLevel(logging.WARNING)
|
|
|
logging.getLogger("p2ch14.dsets").setLevel(logging.WARNING)
|
|
|
|
|
|
def print_confusion(label, confusions, do_mal):
|
|
|
+ row_labels = ['Non-Nodules', 'Benign', 'Malignant']
|
|
|
+
|
|
|
if do_mal:
|
|
|
- col_labels = ['', 'Complete Miss', 'Filtered', 'Benign', 'Malignant']
|
|
|
- row_labels = ['Non-Nodules', 'Benign', 'Malignant']
|
|
|
+ col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Benign', 'Pred. Malignant']
|
|
|
else:
|
|
|
- col_labels = ['', 'Complete Miss', 'Filtered', 'Detected']
|
|
|
- row_labels = ['Non-Nodules', 'Nodules']
|
|
|
- confusions[-2] += confusions[-1]
|
|
|
+ col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Nodule']
|
|
|
confusions[:, -2] += confusions[:, -1]
|
|
|
- confusions = confusions[:-1, :-1]
|
|
|
- cell_width = 14
|
|
|
+ confusions = confusions[:, :-1]
|
|
|
+ cell_width = 16
|
|
|
f = '{:>' + str(cell_width) + '}'
|
|
|
print(label)
|
|
|
print(' | '.join([f.format(s) for s in col_labels]))
|
|
|
@@ -72,7 +71,7 @@ def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
|
|
|
confusion = np.zeros((3, 4), dtype=np.int)
|
|
|
if len(detected_xyz) == 0:
|
|
|
for tn in true_nodules:
|
|
|
- confusiion[2 if tn.isMal_bool else 1, 0] += 1
|
|
|
+ confusion[2 if tn.isMal_bool else 1, 0] += 1
|
|
|
elif len(truth_xyz) == 0:
|
|
|
for dc in detected_classes:
|
|
|
confusion[0, dc] += 1
|
|
|
@@ -124,7 +123,7 @@ class NoduleAnalysisApp:
|
|
|
parser.add_argument('--segmentation-path',
|
|
|
help="Path to the saved segmentation model",
|
|
|
nargs='?',
|
|
|
- default=None,
|
|
|
+ default='data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state',
|
|
|
)
|
|
|
|
|
|
parser.add_argument('--cls-model',
|
|
|
@@ -135,13 +134,14 @@ class NoduleAnalysisApp:
|
|
|
parser.add_argument('--classification-path',
|
|
|
help="Path to the saved classification model",
|
|
|
nargs='?',
|
|
|
- default=None,
|
|
|
+ default='data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state',
|
|
|
)
|
|
|
|
|
|
parser.add_argument('--malignancy-model',
|
|
|
help="What to model class name to use for the malignancy classifier.",
|
|
|
action='store',
|
|
|
- default='ModifiedLunaModel',
|
|
|
+ default='LunaModel',
|
|
|
+ # default='ModifiedLunaModel',
|
|
|
)
|
|
|
parser.add_argument('--malignancy-path',
|
|
|
help="Path to the saved malignancy classification model",
|
|
|
@@ -303,7 +303,6 @@ class NoduleAnalysisApp:
|
|
|
val_list = sorted(series_set & val_set)
|
|
|
|
|
|
|
|
|
- candidateInfo_list = []
|
|
|
candidateInfo_dict = getCandidateInfoDict()
|
|
|
series_iter = enumerateWithEstimate(
|
|
|
val_list + train_list,
|
|
|
@@ -314,10 +313,8 @@ class NoduleAnalysisApp:
|
|
|
ct = getCt(series_uid)
|
|
|
mask_a = self.segmentCt(ct, series_uid)
|
|
|
|
|
|
- candidateInfo_list = self.clusterSegmentationOutput(
|
|
|
- series_uid,
|
|
|
- ct,
|
|
|
- mask_a,
|
|
|
+ candidateInfo_list = self.groupSegmentationOutput(
|
|
|
+ series_uid, ct, mask_a
|
|
|
)
|
|
|
classifications_list = self.classifyCandidates(ct, candidateInfo_list)
|
|
|
|
|
|
@@ -339,7 +336,6 @@ class NoduleAnalysisApp:
|
|
|
print_confusion("Total", all_confusion, self.malignancy_model is not None)
|
|
|
|
|
|
|
|
|
-
|
|
|
def classifyCandidates(self, ct, candidateInfo_list):
|
|
|
cls_dl = self.initClassificationDl(candidateInfo_list)
|
|
|
classifications_list = []
|
|
|
@@ -348,22 +344,26 @@ class NoduleAnalysisApp:
|
|
|
|
|
|
input_g = input_t.to(self.device)
|
|
|
with torch.no_grad():
|
|
|
- _, probability_g = self.cls_model(input_g)
|
|
|
+ _, probability_nodule_g = self.cls_model(input_g)
|
|
|
if self.malignancy_model is not None:
|
|
|
_, probability_mal_g = self.malignancy_model(input_g)
|
|
|
else:
|
|
|
- probability_mal_g = torch.zeros_like(probability_g)
|
|
|
+ probability_mal_g = torch.zeros_like(probability_nodule_g)
|
|
|
|
|
|
- for center_irc, prob, prob_mal in zip(center_list,
|
|
|
- probability_g[:,1].tolist(),
|
|
|
- probability_mal_g[:,1].tolist()
|
|
|
- ):
|
|
|
+ zip_iter = zip(
|
|
|
+ center_list,
|
|
|
+ probability_nodule_g[:,1].tolist(),
|
|
|
+ probability_mal_g[:,1].tolist(),
|
|
|
+ )
|
|
|
+ for center_irc, prob_nodule, prob_mal in zip_iter:
|
|
|
center_xyz = irc2xyz(
|
|
|
center_irc,
|
|
|
direction_a=ct.direction_a,
|
|
|
origin_xyz=ct.origin_xyz,
|
|
|
- vxSize_xyz=ct.vxSize_xyz)
|
|
|
- classifications_list.append((prob, prob_mal, center_xyz, center_irc))
|
|
|
+ vxSize_xyz=ct.vxSize_xyz,
|
|
|
+ )
|
|
|
+ cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
|
|
|
+ classifications_list.append(cls_tup)
|
|
|
return classifications_list
|
|
|
|
|
|
def segmentCt(self, ct, series_uid):
|
|
|
@@ -371,26 +371,23 @@ class NoduleAnalysisApp:
|
|
|
output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
|
|
|
seg_dl = self.initSegmentationDl(series_uid)
|
|
|
for batch_tup in seg_dl:
|
|
|
- input_t = batch_tup[0]
|
|
|
- ndx_list = batch_tup[4]
|
|
|
+ input_t, label_t, series_list, slice_ndx_list = batch_tup
|
|
|
|
|
|
input_g = input_t.to(self.device)
|
|
|
prediction_g = self.seg_model(input_g)
|
|
|
|
|
|
- for i, sample_ndx in enumerate(ndx_list):
|
|
|
- output_a[sample_ndx] = prediction_g[i].cpu().numpy()
|
|
|
+ for i, slice_ndx in enumerate(slice_ndx_list):
|
|
|
+ output_a[slice_ndx] = prediction_g[i].cpu().numpy()
|
|
|
|
|
|
- # mask_a = output_a > 0.25
|
|
|
mask_a = output_a > 0.5
|
|
|
- # mask_a = morphology.binary_erosion(mask_a, iterations=1)
|
|
|
- # mask_a = morphology.binary_dilation(mask_a, iterations=2)
|
|
|
+ mask_a = morphology.binary_erosion(mask_a, iterations=1)
|
|
|
|
|
|
return mask_a
|
|
|
|
|
|
- def clusterSegmentationOutput(self, series_uid, ct, clean_a):
|
|
|
+ def groupSegmentationOutput(self, series_uid, ct, clean_a):
|
|
|
candidateLabel_a, candidate_count = measurements.label(clean_a)
|
|
|
centerIrc_list = measurements.center_of_mass(
|
|
|
- ct.hu_a + 1001,
|
|
|
+ ct.hu_a.clip(-1000, 1000) + 1001,
|
|
|
labels=candidateLabel_a,
|
|
|
index=np.arange(1, candidate_count+1),
|
|
|
)
|