@@ -0,0 +1,593 @@
+import argparse
+import datetime
+import glob
+import os
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import Luna2dSegmentationDataset, LunaClassificationDataset, getCt, getNoduleInfoList
+from util.logconf import logging
+from util.util import xyz2irc, irc2xyz
+from .model import UNetWrapper, LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+# METRICS_LABEL_NDX=0
+# METRICS_PRED_NDX=1
+# METRICS_LOSS_NDX=2
+# METRICS_MAL_LOSS_NDX=3
+# METRICS_BEN_LOSS_NDX=4
+# METRICS_LUNG_LOSS_NDX=5
+# METRICS_MASKLOSS_NDX=2
+# METRICS_MALLOSS_NDX=3
+
+
+METRICS_LOSS_NDX = 0
+METRICS_LABEL_NDX = 1
+METRICS_MFOUND_NDX = 2
+
+METRICS_MOK_NDX = 3
+METRICS_MTP_NDX = 4
+METRICS_MFN_NDX = 5
+METRICS_MFP_NDX = 6
+METRICS_BTP_NDX = 7
+METRICS_BFN_NDX = 8
+METRICS_BFP_NDX = 9
+
+METRICS_MAL_LOSS_NDX = 10
+METRICS_BEN_LOSS_NDX = 11
+METRICS_SIZE = 12
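+# The metrics tensors built below are (METRICS_SIZE, num_samples): one row per
+# statistic above, one column per sample in the epoch.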
+
+
+class LunaDiagnoseApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            log.debug(sys.argv)
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for inference',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        parser.add_argument('--series-uid',
+            help='Limit inference to this Series UID only.',
+            default=None,
+            type=str,
+        )
+
+        parser.add_argument('segmentation_path',
+            help="Path to the saved segmentation model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('classification_path',
+            help="Path to the saved classification model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch10',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        # self.optimizer = self.initOptimizer()
+
+        if not self.cli_args.segmentation_path:
+            file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, 'seg_{}_{}.{}.state'.format('*', '*', 'best'))
+            # log.debug(file_path)
+            # sorted() because glob.glob does not guarantee ordering; take the last match
+            self.cli_args.segmentation_path = sorted(glob.glob(file_path))[-1]
+
+        log.debug(self.cli_args.segmentation_path)
+
+        # if not self.cli_args.classification_path:
+        #     file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, 'cls_{}_{}.{}.state'.format('*', '*', 'best'))
+        #     self.cli_args.classification_path = glob.glob(file_path)[-1]
+
+        self.seg_model, self.cls_model = self.initModels()
+
+
+    def initModels(self):
+        log.debug(self.cli_args.segmentation_path)
+        seg_dict = torch.load(self.cli_args.segmentation_path)
+
+        seg_model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
+        seg_model.load_state_dict(seg_dict['model_state'])
+        seg_model.eval()
+
+        # cls_dict = torch.load(self.cli_args.classification_path)
+
+        cls_model = LunaModel()
+        # cls_model.load_state_dict(cls_dict['model_state'])
+        cls_model.eval()
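+        # NOTE: classification weights are never loaded above, so cls_model
+        # runs with randomly initialized parameters until the commented-out
+        # classification_path handling is restored.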
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                seg_model = nn.DataParallel(seg_model)
+                cls_model = nn.DataParallel(cls_model)
+
+            seg_model = seg_model.to(self.device)
+            cls_model = cls_model.to(self.device)
+
+        return seg_model, cls_model
+
+
+    def initSegmentationDl(self, series_uid):
+        seg_ds = Luna2dSegmentationDataset(
+            test_stride=10,
+            contextSlices_count=3,
+            series_uid=series_uid,
+        )
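+        # batch size is scaled by GPU count so each DataParallel replica sees
+        # cli_args.batch_size samples per step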
+        seg_dl = DataLoader(
+            seg_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return seg_dl
+
+    def initClassificationDl(self, series_uid):
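+        # Mirrors initSegmentationDl; only reached from the commented-out
+        # classification pass in main() below.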
+        cls_ds = LunaClassificationDataset(
+            test_stride=10,
+            # contextSlices_count=3,
+            series_uid=series_uid,
+        )
+        cls_dl = DataLoader(
+            cls_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return cls_dl
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        if self.cli_args.series_uid:
+            series_list = [self.cli_args.series_uid]
+        else:
+            series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        with torch.no_grad():
+            series_iter = enumerateWithEstimate(
+                series_list,
+                "Series",
+            )
+            for series_ndx, series_uid in series_iter:
+                seg_dl = self.initSegmentationDl(series_uid)
+                ct = getCt(series_uid)
+
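+                # full-volume buffer; per-slice mask predictions are written
+                # into it as the loader walks the CT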
+                output_ary = np.zeros_like(ct.ary, dtype=np.float32)
+
+                # testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+                batch_iter = enumerateWithEstimate(
+                    seg_dl,
+                    "Seg " + series_uid,
+                    start_ndx=seg_dl.num_workers,
+                )
+                for batch_ndx, batch_tup in batch_iter:
+                    # self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+                    input_tensor, label_tensor, _series_list, ndx_list = batch_tup
+
+                    input_devtensor = input_tensor.to(self.device)
+
+                    prediction_devtensor = self.seg_model(input_devtensor)
+
+                    for i, sample_ndx in enumerate(ndx_list):
+                        # NOTE: channel 0 is assumed to hold the malignant-mask output here
+                        output_ary[sample_ndx] = prediction_devtensor[i, 0].detach().cpu().numpy()
+
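+                # threshold the probability mask, then map the flagged voxel
+                # indices (index, row, col) into patient-space coordinates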
+                irc = (output_ary > 0.5).nonzero()
+                xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
+
+                print(irc, xyz)
+
+
+        #
+        # cls_dl = self.initClassificationDl(series_uid)
+        #
+        # # testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+        # batch_iter = enumerateWithEstimate(
+        #     cls_dl,
+        #     "Cls " + series_uid,
+        #     start_ndx=cls_dl.num_workers,
+        # )
+        # for batch_ndx, batch_tup in batch_iter:
+        #     self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+        #
+        #
+        # for epoch_ndx in range(1, self.cli_args.epochs + 1):
+        #     train_dl = self.initTrainDl(epoch_ndx)
+        #
+        #     log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+        #         epoch_ndx,
+        #         self.cli_args.epochs,
+        #         len(train_dl),
+        #         len(test_dl),
+        #         self.cli_args.batch_size,
+        #         (torch.cuda.device_count() if self.use_cuda else 1),
+        #     ))
+        #
+        #     trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+        #     if self.cli_args.segmentation:
+        #         self.logImages(epoch_ndx, train_dl, test_dl)
+        #
+        #     testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+        #     self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
+        #
+        #     self.saveModel(epoch_ndx)
+        #
+        # if hasattr(self, 'trn_writer'):
+        #     self.trn_writer.close()
+        #     self.tst_writer.close()
+
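+    # NOTE: the methods below were carried over from the training script and
+    # still reference attributes this diagnose app never sets (self.model,
+    # self.optimizer, self.totalTrainingSamples_count, self.cli_args.segmentation,
+    # self.cli_args.comment, self.initTensorboardWriters); they need to be
+    # adapted before they can run here.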
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
+        train_dl.dataset.shuffleSamples()
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            if self.cli_args.segmentation:
+                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+            else:
+                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+
+            if loss_var is not None:
+                loss_var.backward()
+                self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
+
+        return trainingMetrics_tensor
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            self.model.eval()
+            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                if self.cli_args.segmentation:
+                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+                else:
+                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+
+        return testingMetrics_tensor
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
+        # NOTE: METRICS_PRED_NDX is no longer defined in the constants above;
+        # this line will raise a NameError if computeBatchLoss is ever called.
+        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
+        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        return loss_devtensor.mean()
+
+    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _start_list = batch_tup
+
+        # if label_tensor.max() < 0.5:
+        #     return None
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+
+        # assert prediction_devtensor.is_contiguous()
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
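+        # max2: per-sample max over all pixels; intersectionSum: per-sample
+        # count of pixels where both masks are "on"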
+        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
+        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)
+
+        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
+
+        with torch.no_grad():
+
+            boolPrediction_tensor = prediction_devtensor.to('cpu') > 0.5
+
+            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0])
+            metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:, 0] * boolPrediction_tensor[:, 1].to(torch.float32)) > 0.5)
+
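+            # MOK: malignant label pixels flagged by either output channel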
+            metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0], torch.max(boolPrediction_tensor, dim=1)[0])
+
+            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0],  boolPrediction_tensor[:,0])
+            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0], ~boolPrediction_tensor[:,0])
+            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0],  boolPrediction_tensor[:,0])
+
+            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1],  boolPrediction_tensor[:,1])
+            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1], ~boolPrediction_tensor[:,1])
+            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1],  boolPrediction_tensor[:,1])
+
+            diceLoss_tensor = diceLoss_devtensor.to('cpu')
+            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor
+
+            malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
+            malLoss_tensor = malLoss_devtensor.to('cpu')#.unsqueeze(1)
+            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor
+
+            benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])
+            benLoss_tensor = benLoss_devtensor.to('cpu')#.unsqueeze(1)
+            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor
+
+            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
+            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
+            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
+
+        return diceLoss_devtensor.mean()
+        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
+
+    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01):
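+        # per-sample soft Dice loss: 1 - (2*|pred * label| + eps) / (|pred| + |label| + eps);
+        # epsilon keeps the ratio defined when both masks are empty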
+        # sum2 = lambda t: t.sum([1,2,3,4])
+        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
+
+        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
+        dicePrediction_devtensor = sum2(prediction_devtensor)
+        diceLabel_devtensor = sum2(label_devtensor)
+        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
+        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
+
+        return diceLoss_devtensor
+
+
+    def logImages(self, epoch_ndx, train_dl, test_dl):
+        if epoch_ndx > 0: # TODO revert
+            self.initTensorboardWriters()
+
+        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
+            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
+                ct = getCt(series_uid)
+                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
+                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
+
+                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
+                input_tensor = sample_tup[0].unsqueeze(0)
+                label_tensor = sample_tup[1].unsqueeze(0)
+
+                input_devtensor = input_tensor.to(self.device)
+                label_devtensor = label_tensor.to(self.device)
+
+                prediction_devtensor = self.model(input_devtensor)
+                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
+
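+                # RGB overlay: grayscale CT slice as backdrop, red = channel 0
+                # (malignant) prediction, green = channel 1 (benign) prediction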
+                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
+                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
+                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
+                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
+
+                # log.debug([image_ary.__array_interface__['typestr']])
+
+                # image_ary = (image_ary * 255).astype(np.uint8)
+
+                # log.debug([image_ary.__array_interface__['typestr']])
+
+                writer = getattr(self, mode_str + '_writer')
+                # dataformats='HWC' because image_ary is (H, W, C), not the default (C, H, W)
+                writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+
+                if epoch_ndx == 1:
+                    label_ary = label_tensor.numpy()
+
+                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
+                    image_ary[:,:,0] += label_ary[0,0] * 0.5
+                    image_ary[:,:,1] += label_ary[0,1] * 0.25
+                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(bool) | label_ary[0,1].astype(bool))) * 0.25
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    image_ary = (image_ary * 255).astype(np.uint8)
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    writer = getattr(self, mode_str + '_writer')
+                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+
+
+    def logMetrics(self,
+            epoch_ndx,
+            trainingMetrics_tensor,
+            testingMetrics_tensor,
+            classificationThreshold_float=0.5,
+        ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+            metrics_ary = metrics_tensor.cpu().detach().numpy()
+            sum_ary = metrics_ary.sum(axis=1)
+            assert np.isfinite(metrics_ary).all()
+
+            malLabel_mask = metrics_ary[METRICS_LABEL_NDX] > classificationThreshold_float
+            malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
+
+            # malLabel_mask = ~benLabel_mask
+            # malPred_mask = ~benPred_mask
+
+            benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
+            malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+
+            trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
+            truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
+
+            # falsePos_count = benLabel_count - benCorrect_count
+            # falseNeg_count = malLabel_count - malCorrect_count
+
+            metrics_dict = {}
+            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+            # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
+            # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
+            # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
+            metrics_dict['loss/mal'] = metrics_ary[METRICS_MAL_LOSS_NDX].mean()
+            metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX].mean()
+
+            metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+            metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
+
+            metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+            metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100
+
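+            # precision = TP / (TP + FP), recall = TP / (TP + FN),
+            # F1 = 2 * precision * recall / (precision + recall);
+            # the "or 1" guards against division by zero on empty epochs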
+            precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
+            recall = metrics_dict['pr/recall'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
+
+            metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+            log.info(("E{} {:8} "
+                + "{loss/all:.4f} loss, "
+                + "{flagged/all:-5.1f}% pixels flagged, "
+                + "{flagged/slices:-5.1f}% slices flagged, "
+                + "{pr/precision:.4f} precision, "
+                + "{pr/recall:.4f} recall, "
+                + "{pr/f1_score:.4f} f1 score"
+            ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                + "{loss/mal:.4f} loss, "
+                + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                + "{loss/ben:.4f} loss, "
+                + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_ben',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
+                **metrics_dict,
+            ))
+
+            if epoch_ndx > 0: # TODO revert
+                self.initTensorboardWriters()
+                writer = getattr(self, mode_str + '_writer')
+
+                for key, value in metrics_dict.items():
+                    writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)
+
+                # writer.add_pr_curve(
+                #     'pr',
+                #     metrics_ary[METRICS_LABEL_NDX],
+                #     metrics_ary[METRICS_PRED_NDX],
+                #     self.totalTrainingSamples_count,
+                # )
+
+                # benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+                # malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+
+                # bins = [x/50.0 for x in range(51)]
+                # writer.add_histogram(
+                #     'is_ben',
+                #     metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                #     self.totalTrainingSamples_count,
+                #     bins=bins,
+                # )
+                # writer.add_histogram(
+                #     'is_mal',
+                #     metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                #     self.totalTrainingSamples_count,
+                #     bins=bins,
+                # )
+
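+    # NOTE: saveModel still expects self.time_str, self.cli_args.comment,
+    # self.model, and self.optimizer from the training app; none of them are
+    # set up in LunaDiagnoseApp as written.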
+    def saveModel(self, epoch_ndx):
+        file_path = os.path.join('data', 'models', self.cli_args.tb_prefix, '{}_{}.{}.state'.format(self.time_str, self.cli_args.comment, self.totalTrainingSamples_count))
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        state = {
+            'model_state': self.model.state_dict(),
+            'model_name': type(self.model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaDiagnoseApp().main() or 0)