import argparse
import datetime
import glob
import os
import sys

import numpy as np

from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import Luna2dSegmentationDataset, LunaClassificationDataset, getCt, getNoduleInfoList
from util.logconf import logging
from util.util import xyz2irc, irc2xyz
from .model import UNetWrapper, LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
# METRICS_LABEL_NDX=0
# METRICS_PRED_NDX=1
# METRICS_LOSS_NDX=2
# METRICS_MAL_LOSS_NDX=3
# METRICS_BEN_LOSS_NDX=4
# METRICS_LUNG_LOSS_NDX=5
# METRICS_MASKLOSS_NDX=2
# METRICS_MALLOSS_NDX=3
METRICS_LOSS_NDX = 0
METRICS_LABEL_NDX = 1
METRICS_MFOUND_NDX = 2
METRICS_MOK_NDX = 3
METRICS_MTP_NDX = 4
METRICS_MFN_NDX = 5
METRICS_MFP_NDX = 6
METRICS_BTP_NDX = 7
METRICS_BFN_NDX = 8
METRICS_BFP_NDX = 9
METRICS_MAL_LOSS_NDX = 10
METRICS_BEN_LOSS_NDX = 11
METRICS_PRED_NDX = 12  # used by computeBatchLoss; was only defined in the commented-out scheme above
METRICS_SIZE = 13
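
# Each metrics tensor is laid out as (METRICS_SIZE, num_samples): one row per
# statistic above, one column per sample. computeSegmentationLoss fills the
# per-pixel counts (TP/FN/FP for each output channel) and per-sample losses,
# and logMetrics reduces the columns into per-epoch precision/recall/F1.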
class LunaDiagnoseApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            log.debug(sys.argv)
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for inference',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--series-uid',
            help='Limit inference to this Series UID only.',
            default=None,
            type=str,
        )
        parser.add_argument('segmentation_path',
            help="Path to the saved segmentation model",
            nargs='?',
            default=None,
        )
        parser.add_argument('classification_path',
            help="Path to the saved classification model",
            nargs='?',
            default=None,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch10',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )

        self.cli_args = parser.parse_args(sys_argv)
        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        # self.optimizer = self.initOptimizer()

        if not self.cli_args.segmentation_path:
            file_path = os.path.join(
                'data-unversioned', 'models', self.cli_args.tb_prefix,
                'seg_{}_{}.{}.state'.format('*', '*', 'best'))
            # log.debug(file_path)
            # sorted so that [-1] picks the most recent matching checkpoint
            self.cli_args.segmentation_path = sorted(glob.glob(file_path))[-1]
            log.debug(self.cli_args.segmentation_path)

        # if not self.cli_args.classification_path:
        #     file_path = os.path.join(
        #         'data-unversioned', 'models', self.cli_args.tb_prefix,
        #         'cls_{}_{}.{}.state'.format('*', '*', 'best'))
        #     self.cli_args.classification_path = sorted(glob.glob(file_path))[-1]

        self.seg_model, self.cls_model = self.initModels()
    def initModels(self):
        log.debug(self.cli_args.segmentation_path)
        seg_dict = torch.load(self.cli_args.segmentation_path)

        seg_model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
        seg_model.load_state_dict(seg_dict['model_state'])
        seg_model.eval()

        # cls_dict = torch.load(self.cli_args.classification_path)
        cls_model = LunaModel()
        # cls_model.load_state_dict(cls_dict['model_state'])
        cls_model.eval()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                seg_model = nn.DataParallel(seg_model)
                cls_model = nn.DataParallel(cls_model)
            seg_model = seg_model.to(self.device)
            cls_model = cls_model.to(self.device)

        return seg_model, cls_model
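
    # Both data-loader helpers scale batch_size by the GPU count so that
    # nn.DataParallel still hands each device the requested per-device batch.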
    def initSegmentationDl(self, series_uid):
        seg_ds = Luna2dSegmentationDataset(
            test_stride=10,
            contextSlices_count=3,
            series_uid=series_uid,
        )
        seg_dl = DataLoader(
            seg_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return seg_dl
    def initClassificationDl(self, series_uid):
        cls_ds = LunaClassificationDataset(
            test_stride=10,
            # contextSlices_count=3,
            series_uid=series_uid,
        )
        cls_dl = DataLoader(
            cls_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return cls_dl
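
    # main() runs pure inference: for every series (or just --series-uid),
    # segment each 2D slice with the U-Net, stitch the per-slice predictions
    # back into a full-volume array, threshold it, and report the flagged
    # voxels in both IRC (index/row/col) and patient-space XYZ coordinates.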
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        if self.cli_args.series_uid:
            series_list = [self.cli_args.series_uid]
        else:
            series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))

        with torch.no_grad():
            series_iter = enumerateWithEstimate(
                series_list,
                "Series",
            )
            for series_ndx, series_uid in series_iter:
                seg_dl = self.initSegmentationDl(series_uid)
                ct = getCt(series_uid)
                output_ary = np.zeros_like(ct.ary, dtype=np.float32)

                # testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
                batch_iter = enumerateWithEstimate(
                    seg_dl,
                    "Seg " + series_uid,
                    start_ndx=seg_dl.num_workers,
                )
                for batch_ndx, batch_tup in batch_iter:
                    # self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
                    input_tensor, label_tensor, _series_list, ndx_list = batch_tup
                    input_devtensor = input_tensor.to(self.device)
                    prediction_devtensor = self.seg_model(input_devtensor)

                    for i, sample_ndx in enumerate(ndx_list):
                        # Collapse the per-class output channels to one score
                        # per voxel; output_ary is a plain 3D CT-shaped array,
                        # while each prediction is (n_classes, H, W).
                        output_ary[sample_ndx] = prediction_devtensor[i].detach().max(dim=0)[0].cpu().numpy()

                irc = (output_ary > 0.5).nonzero()
                xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
                print(irc, xyz)

                # cls_dl = self.initClassificationDl(series_uid)
                #
                # # testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
                # batch_iter = enumerateWithEstimate(
                #     cls_dl,
                #     "Cls " + series_uid,
                #     start_ndx=cls_dl.num_workers,
                # )
                # for batch_ndx, batch_tup in batch_iter:
                #     self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)

        # for epoch_ndx in range(1, self.cli_args.epochs + 1):
        #     train_dl = self.initTrainDl(epoch_ndx)
        #
        #     log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
        #         epoch_ndx,
        #         self.cli_args.epochs,
        #         len(train_dl),
        #         len(test_dl),
        #         self.cli_args.batch_size,
        #         (torch.cuda.device_count() if self.use_cuda else 1),
        #     ))
        #
        #     trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
        #     if self.cli_args.segmentation:
        #         self.logImages(epoch_ndx, train_dl, test_dl)
        #
        #     testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
        #     self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
        #
        #     self.saveModel(epoch_ndx)
        #
        # if hasattr(self, 'trn_writer'):
        #     self.trn_writer.close()
        #     self.tst_writer.close()
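
    # NOTE: the methods below are carried over from the training prototype.
    # They reference attributes this diagnose app never sets (self.model,
    # self.optimizer, self.totalTrainingSamples_count, and cli args such as
    # --epochs and --segmentation), so they are not reachable from main()
    # as written.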
    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
        train_dl.dataset.shuffleSamples()
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            if self.cli_args.segmentation:
                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
            else:
                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)

            if loss_var is not None:
                loss_var.backward()
                self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)

        return trainingMetrics_tensor
    def doTesting(self, epoch_ndx, test_dl):
        with torch.no_grad():
            self.model.eval()
            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                if self.cli_args.segmentation:
                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
                else:
                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)

        return testingMetrics_tensor
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _center_list = batch_tup

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.detach().to('cpu')
        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.detach().to('cpu')

        # TODO: replace with torch.autograd.detect_anomaly
        # assert np.isfinite(metrics_tensor).all()

        return loss_devtensor.mean()
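
    # computeSegmentationLoss records, per sample, how many pixels fall into
    # true-positive / false-negative / false-positive buckets for each output
    # channel (index 0 = malignant, index 1 = benign, per the METRICS_M*/B*
    # naming), alongside the per-sample Dice losses.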
    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _start_list = batch_tup
        # if label_tensor.max() < 0.5:
        #     return None

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        # assert prediction_devtensor.is_contiguous()

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)

        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
        with torch.no_grad():
            boolPrediction_tensor = prediction_devtensor.to('cpu') > 0.5

            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0])
            metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:,0] * boolPrediction_tensor[:,1].to(torch.float32)) > 0.5)
            metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], torch.max(boolPrediction_tensor, dim=1)[0])

            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0],  boolPrediction_tensor[:,0])
            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0], ~boolPrediction_tensor[:,0])
            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0],  boolPrediction_tensor[:,0])

            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1],  boolPrediction_tensor[:,1])
            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1], ~boolPrediction_tensor[:,1])
            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1],  boolPrediction_tensor[:,1])

            diceLoss_tensor = diceLoss_devtensor.to('cpu')
            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor

            malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
            malLoss_tensor = malLoss_devtensor.to('cpu')  # .unsqueeze(1)
            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor

            benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])
            benLoss_tensor = benLoss_devtensor.to('cpu')  # .unsqueeze(1)
            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor

            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor

        # TODO: replace with torch.autograd.detect_anomaly
        # assert np.isfinite(metrics_tensor).all()

        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
        return diceLoss_devtensor.mean()
        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
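
    # Per-sample soft Dice: dice = (2*|P∩L| + eps) / (|P| + |L| + eps), and
    # loss = 1 - dice. The epsilon term keeps the ratio defined (and the
    # gradient finite) when a sample contains no positive pixels at all.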
    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01):
        # sum2 = lambda t: t.sum([1,2,3,4])
        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]

        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
        dicePrediction_devtensor = sum2(prediction_devtensor)
        diceLabel_devtensor = sum2(label_devtensor)
        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)

        return diceLoss_devtensor
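
    # logImages renders one slice per series as an RGB overlay: the grayscale
    # CT slice in every channel, predicted malignant pixels boosted in red and
    # predicted benign pixels in green; the label image also adds the last
    # input channel (minus the labeled pixels) in blue.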
    def logImages(self, epoch_ndx, train_dl, test_dl):
        if epoch_ndx > 0:  # TODO: revert
            self.initTensorboardWriters()

        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
                ct = getCt(series_uid)
                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)

                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
                input_tensor = sample_tup[0].unsqueeze(0)
                label_tensor = sample_tup[1].unsqueeze(0)

                input_devtensor = input_tensor.to(self.device)
                label_devtensor = label_tensor.to(self.device)

                prediction_devtensor = self.model(input_devtensor)
                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()

                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                image_ary[:,:,:] = input_tensor[0,2].numpy().reshape((512,512,1)) * 0.25
                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5

                # log.debug([image_ary.__array_interface__['typestr']])
                # image_ary = (image_ary * 255).astype(np.uint8)
                # log.debug([image_ary.__array_interface__['typestr']])

                writer = getattr(self, mode_str + '_writer')
                writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count)

                if epoch_ndx == 1:
                    label_ary = label_tensor.numpy()

                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                    image_ary[:,:,:] = input_tensor[0,2].numpy().reshape((512,512,1)) * 0.25
                    image_ary[:,:,0] += label_ary[0,0] * 0.5
                    image_ary[:,:,1] += label_ary[0,1] * 0.25
                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(bool) | label_ary[0,1].astype(bool))) * 0.25
                    # log.debug([image_ary.__array_interface__['typestr']])
                    image_ary = (image_ary * 255).astype(np.uint8)
                    # log.debug([image_ary.__array_interface__['typestr']])

                    writer = getattr(self, mode_str + '_writer')
                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count)
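
    # logMetrics reduces the per-sample pixel counts into epoch-level scores,
    # using the malignant-channel counts:
    #   precision = TP / (TP + FP),  recall = TP / (TP + FN),
    #   F1 = 2 * precision * recall / (precision + recall),
    # plus per-channel "percent of labeled pixels recovered" numbers.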
    def logMetrics(self,
            epoch_ndx,
            trainingMetrics_tensor,
            testingMetrics_tensor,
            classificationThreshold_float=0.5,
    ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
            metrics_ary = metrics_tensor.cpu().detach().numpy()
            sum_ary = metrics_ary.sum(axis=1)
            assert np.isfinite(metrics_ary).all()

            malLabel_mask = metrics_ary[METRICS_LABEL_NDX] > classificationThreshold_float
            malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
            # malLabel_mask = ~benLabel_mask
            # malPred_mask = ~benPred_mask

            benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
            malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]

            trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
            truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
            #
            # falsePos_count = benLabel_count - benCorrect_count
            # falseNeg_count = malLabel_count - malCorrect_count

            metrics_dict = {}
            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
            # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
            # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
            # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
            metrics_dict['loss/mal'] = metrics_ary[METRICS_MAL_LOSS_NDX].mean()
            metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX].mean()

            metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
            metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100

            metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
            metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100

            precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
            recall = metrics_dict['pr/recall'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
            metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

            log.info(("E{} {:8} "
                + "{loss/all:.4f} loss, "
                + "{flagged/all:-5.1f}% pixels flagged, "
                + "{flagged/slices:-5.1f}% slices flagged, "
                + "{pr/precision:.4f} precision, "
                + "{pr/recall:.4f} recall, "
                + "{pr/f1_score:.4f} f1 score"
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            ))
            log.info(("E{} {:8} "
                + "{loss/mal:.4f} loss, "
                + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_mal',
                malCorrect_count=malCorrect_count,
                malLabel_count=malLabel_count,
                **metrics_dict,
            ))
            log.info(("E{} {:8} "
                + "{loss/ben:.4f} loss, "
                + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_ben',
                benCorrect_count=benCorrect_count,
                benLabel_count=benLabel_count,
                **metrics_dict,
            ))

            if epoch_ndx > 0:  # TODO: revert
                self.initTensorboardWriters()
                writer = getattr(self, mode_str + '_writer')

                for key, value in metrics_dict.items():
                    writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)

                # writer.add_pr_curve(
                #     'pr',
                #     metrics_ary[METRICS_LABEL_NDX],
                #     metrics_ary[METRICS_PRED_NDX],
                #     self.totalTrainingSamples_count,
                # )

                # benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
                # malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
                #
                # bins = [x/50.0 for x in range(51)]
                # writer.add_histogram(
                #     'is_ben',
                #     metrics_ary[METRICS_PRED_NDX, benHist_mask],
                #     self.totalTrainingSamples_count,
                #     bins=bins,
                # )
                # writer.add_histogram(
                #     'is_mal',
                #     metrics_ary[METRICS_PRED_NDX, malHist_mask],
                #     self.totalTrainingSamples_count,
                #     bins=bins,
                # )
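
    # saveModel is likewise inherited from the training app: it expects
    # self.time_str, self.totalTrainingSamples_count, and a --comment CLI
    # argument, none of which are initialized by LunaDiagnoseApp.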
    def saveModel(self, epoch_ndx):
        file_path = os.path.join(
            'data', 'models', self.cli_args.tb_prefix,
            '{}_{}.{}.state'.format(self.time_str, self.cli_args.comment, self.totalTrainingSamples_count))
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        state = {
            'model_state': self.model.state_dict(),
            'model_name': type(self.model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
            # 'resumed_from': self.cli_args.resume,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))
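
# Typical invocation (the module path is an assumption based on the relative
# imports and the 'p2ch10' tb-prefix default; <series_uid> is a placeholder):
#   python -m p2ch10.diagnose --series-uid <series_uid> path/to/seg_model.state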

if __name__ == '__main__':
    sys.exit(LunaDiagnoseApp().main() or 0)