# diagnose.py
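"""Work-in-progress end-to-end diagnosis script.

Loads a saved segmentation model and a saved classification model (both
restored as UNetWrapper instances here) and runs them over LUNA CT data,
logging Dice-based metrics and overlay images to TensorBoard. Several methods
are carried over from the training script; NOTE comments below mark
references that do not yet resolve in this file.
"""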

import argparse
import datetime
import os
import sys

import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import Luna2dSegmentationDataset, TestingLuna2dSegmentationDataset, getCt
from util.logconf import logging
from util.util import xyz2irc
from .model import UNetWrapper

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
# METRICS_LABEL_NDX=0
# METRICS_PRED_NDX=1
# METRICS_LOSS_NDX=2
# METRICS_MAL_LOSS_NDX=3
# METRICS_BEN_LOSS_NDX=4
# METRICS_LUNG_LOSS_NDX=5
# METRICS_MASKLOSS_NDX=2
# METRICS_MALLOSS_NDX=3
METRICS_LOSS_NDX = 0
METRICS_LABEL_NDX = 1
METRICS_MFOUND_NDX = 2
METRICS_MOK_NDX = 3
METRICS_MTP_NDX = 4
METRICS_MFN_NDX = 5
METRICS_MFP_NDX = 6
METRICS_BTP_NDX = 7
METRICS_BFN_NDX = 8
METRICS_BFP_NDX = 9
METRICS_MAL_LOSS_NDX = 10
METRICS_BEN_LOSS_NDX = 11
METRICS_SIZE = 12
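
# The metrics tensors allocated in doTraining()/doTesting() are shaped
# (METRICS_SIZE, num_samples): one row per index above, one column per sample.
# For example, metrics_tensor[METRICS_MTP_NDX] holds each sample's count of
# malignant true-positive pixels.
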
class LunaDiagnoseApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for inference',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--series-uid',
            help='Limit inference to this Series UID only.',
            default=None,
            type=str,
        )
        # NOTE: the training-loop methods below still reference the following
        # options; the defaults here are placeholders.
        parser.add_argument('--segmentation', action='store_true', default=False,
            help='Use the segmentation code paths.',
        )
        parser.add_argument('--epochs', default=1, type=int,
            help='Number of epochs to run.',
        )
        parser.add_argument('--tb-prefix', default='diagnose',
            help='Data prefix to use for Tensorboard run.',
        )
        parser.add_argument('--comment', default='none',
            help='Comment suffix for Tensorboard run.',
        )
        parser.add_argument('segmentation_path',
            help="Path to the saved segmentation model",
        )
        parser.add_argument('classification_path',
            help="Path to the saved classification model",
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.totalTrainingSamples_count = 0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.seg_model, self.cls_model = self.initModels()
        # self.optimizer = self.initOptimizer()
        # NOTE: several methods below still refer to self.model; alias the
        # segmentation model so those references resolve.
        self.model = self.seg_model

    def initModels(self):
        # map_location lets checkpoints saved on GPU load on a CPU-only machine.
        seg_dict = torch.load(self.cli_args.segmentation_path, map_location='cpu')
        seg_model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6,
                                padding=True, batch_norm=True, up_mode='upconv')
        seg_model.load_state_dict(seg_dict['model_state'])
        seg_model.eval()

        # NOTE: instantiating the classifier as a UNetWrapper is a placeholder;
        # the checkpoint at classification_path supplies the weights.
        cls_dict = torch.load(self.cli_args.classification_path, map_location='cpu')
        cls_model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6,
                                padding=True, batch_norm=True, up_mode='upconv')
        cls_model.load_state_dict(cls_dict['model_state'])
        cls_model.eval()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                seg_model = nn.DataParallel(seg_model)
                cls_model = nn.DataParallel(cls_model)
            seg_model = seg_model.to(self.device)
            cls_model = cls_model.to(self.device)

        return seg_model, cls_model
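
    # Note on ordering above: the state dicts are loaded *before* any
    # nn.DataParallel wrapping. A state dict saved from a DataParallel-wrapped
    # model prefixes every key with 'module.', so loading into the bare model
    # first avoids a key mismatch.
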
    def initSegmentationDl(self):
        seg_ds = Luna2dSegmentationDataset(
            test_stride=10,
            contextSlices_count=3,
            series_uid=self.cli_args.series_uid,
        )
        seg_dl = DataLoader(
            seg_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return seg_dl
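
    # nn.DataParallel splits each batch across the visible GPUs, so the
    # DataLoader batch size is scaled by torch.cuda.device_count() to keep the
    # per-GPU batch at --batch-size; pin_memory=True speeds host-to-GPU copies.
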
    def initTestDl(self):
        if self.cli_args.segmentation:
            test_ds = TestingLuna2dSegmentationDataset(
                test_stride=10,
                isTestSet_bool=True,
                contextSlices_count=3,
                # scaled_bool=self.cli_args.scaled or self.cli_args.multiscaled or self.cli_args.augmented,
                # multiscaled_bool=self.cli_args.multiscaled,
                # augmented_bool=self.cli_args.augmented,
            )
        else:
            assert False, "Only the segmentation dataset is implemented here."

        test_dl = DataLoader(
            test_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return test_dl

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        # NOTE: this file defines no initTrainDl(); the segmentation loader
        # stands in for the training loader referenced by the loop below.
        train_dl = self.initSegmentationDl()
        test_dl = self.initTestDl()

        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(test_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
            if self.cli_args.segmentation:
                self.logImages(epoch_ndx, train_dl, test_dl)
            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
            self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
            self.saveModel(epoch_ndx)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.tst_writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        # NOTE: self.optimizer is commented out in __init__ and no
        # initOptimizer() is defined in this file; this training loop is
        # carried over from the training script.
        self.model.train()
        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
        train_dl.dataset.shuffleSamples()
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            if self.cli_args.segmentation:
                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
            else:
                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
            if loss_var is not None:
                loss_var.backward()
                self.optimizer.step()
                del loss_var

        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
        return trainingMetrics_tensor

    def doTesting(self, epoch_ndx, test_dl):
        with torch.no_grad():
            self.model.eval()
            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                if self.cli_args.segmentation:
                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
                else:
                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)

        return testingMetrics_tensor

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _center_list = batch_tup

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
        # NOTE: METRICS_PRED_NDX only appears in the commented-out constants at
        # the top of the file; this line raises NameError if this path runs.
        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.detach().to('cpu')
        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.detach().to('cpu')

        # TODO: replace with torch.autograd.detect_anomaly
        # assert np.isfinite(metrics_tensor).all()
        return loss_devtensor.mean()

    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _start_list = batch_tup
        # if label_tensor.max() < 0.5:
        #     return None

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        # assert prediction_devtensor.is_contiguous()

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        # Per-sample reductions: max over all pixels, and the size of the
        # intersection between a label mask and a (boolean) prediction mask.
        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)

        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
        with torch.no_grad():
            boolPrediction_tensor = prediction_devtensor.to('cpu') > 0.5

            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0])
            metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:,0] * boolPrediction_tensor[:,1].to(torch.float32)) > 0.5)
            metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], torch.max(boolPrediction_tensor, dim=1)[0])
            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], boolPrediction_tensor[:,0])
            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], ~boolPrediction_tensor[:,0])
            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0], boolPrediction_tensor[:,0])

            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,1], boolPrediction_tensor[:,1])
            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,1], ~boolPrediction_tensor[:,1])
            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1], boolPrediction_tensor[:,1])

            diceLoss_tensor = diceLoss_devtensor.to('cpu')
            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor

            malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
            malLoss_tensor = malLoss_devtensor.to('cpu')
            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor

            benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])
            benLoss_tensor = benLoss_devtensor.to('cpu')
            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor

            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor

        # TODO: replace with torch.autograd.detect_anomaly
        # assert np.isfinite(metrics_tensor).all()
        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
        return diceLoss_devtensor.mean()
        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
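
    # For each class channel, the counts stored above follow the per-pixel
    # confusion-matrix definitions, summed per sample:
    #   TP = |label AND pred|          intersectionSum(label, pred)
    #   FN = |label AND NOT pred|      intersectionSum(label, ~pred)
    #   FP = |NOT label AND pred|      intersectionSum(1 - label, pred)
    # logMetrics() later reduces these pixel counts to precision, recall, and F1.
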
    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01):
        # sum2 = lambda t: t.sum([1,2,3,4])
        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]

        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
        dicePrediction_devtensor = sum2(prediction_devtensor)
        diceLabel_devtensor = sum2(label_devtensor)
        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
        return diceLoss_devtensor
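
    # diceLoss() above computes a per-sample soft Dice loss:
    #
    #   loss = 1 - (2 * sum(pred * label) + eps) / (sum(pred) + sum(label) + eps)
    #
    # Perfect overlap gives a loss near 0, no overlap a loss near 1. The epsilon
    # term keeps the ratio well-defined when a slice has no labeled pixels, and
    # because the raw (unthresholded) predictions enter the sums, the loss stays
    # differentiable for backprop.
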
    def logImages(self, epoch_ndx, train_dl, test_dl):
        if epoch_ndx > 0: # TODO revert
            # NOTE: initTensorboardWriters() is not defined in this file; it is
            # carried over from the training script.
            self.initTensorboardWriters()

        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
                ct = getCt(series_uid)
                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]

                input_tensor = sample_tup[0].unsqueeze(0)
                label_tensor = sample_tup[1].unsqueeze(0)

                input_devtensor = input_tensor.to(self.device)
                label_devtensor = label_tensor.to(self.device)

                prediction_devtensor = self.model(input_devtensor)
                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()

                # Gray CT slice as background, predicted malignant mask in red,
                # predicted benign mask in green.
                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
                # image_ary = (image_ary * 255).astype(np.uint8)

                writer = getattr(self, mode_str + '_writer')
                writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count)

                if epoch_ndx == 1:
                    # Same overlay, but with ground-truth labels; blue shows the
                    # final input channel with labeled nodule pixels removed.
                    label_ary = label_tensor.numpy()
                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
                    image_ary[:,:,0] += label_ary[0,0] * 0.5
                    image_ary[:,:,1] += label_ary[0,1] * 0.25
                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool_) | label_ary[0,1].astype(np.bool_))) * 0.25
                    image_ary = (image_ary * 255).astype(np.uint8)

                    writer = getattr(self, mode_str + '_writer')
                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count)

    def logMetrics(self,
            epoch_ndx,
            trainingMetrics_tensor,
            testingMetrics_tensor,
            classificationThreshold_float=0.5,
    ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
            metrics_ary = metrics_tensor.cpu().detach().numpy()
            sum_ary = metrics_ary.sum(axis=1)
            assert np.isfinite(metrics_ary).all()

            malLabel_mask = metrics_ary[METRICS_LABEL_NDX] > classificationThreshold_float
            malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
            # malLabel_mask = ~benLabel_mask
            # malPred_mask = ~benPred_mask

            benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
            malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]

            trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
            truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]

            # falsePos_count = benLabel_count - benCorrect_count
            # falseNeg_count = malLabel_count - malCorrect_count

            metrics_dict = {}
            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
            # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
            # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
            # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
            metrics_dict['loss/mal'] = metrics_ary[METRICS_MAL_LOSS_NDX].mean()
            metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX].mean()

            # Percentage of malignant pixels caught by either output channel,
            # and percentage of malignant slices where any overlap was found.
            metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
            metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100

            metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
            metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100

            # The `or 1` guards keep the ratios defined when a denominator is zero.
            precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
            recall = metrics_dict['pr/recall'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
            metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

            log.info(("E{} {:8} "
                + "{loss/all:.4f} loss, "
                + "{flagged/all:-5.1f}% pixels flagged, "
                + "{flagged/slices:-5.1f}% slices flagged, "
                + "{pr/precision:.4f} precision, "
                + "{pr/recall:.4f} recall, "
                + "{pr/f1_score:.4f} f1 score"
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            ))
            log.info(("E{} {:8} "
                + "{loss/mal:.4f} loss, "
                + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_mal',
                malCorrect_count=malCorrect_count,
                malLabel_count=malLabel_count,
                **metrics_dict,
            ))
            log.info(("E{} {:8} "
                + "{loss/ben:.4f} loss, "
                + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_ben',
                benCorrect_count=benCorrect_count,
                benLabel_count=benLabel_count,
                **metrics_dict,
            ))

            if epoch_ndx > 0: # TODO revert
                self.initTensorboardWriters()
                writer = getattr(self, mode_str + '_writer')
                for key, value in metrics_dict.items():
                    writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)

                # writer.add_pr_curve(
                #     'pr',
                #     metrics_ary[METRICS_LABEL_NDX],
                #     metrics_ary[METRICS_PRED_NDX],
                #     self.totalTrainingSamples_count,
                # )

                # benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
                # malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
                #
                # bins = [x/50.0 for x in range(51)]
                # writer.add_histogram(
                #     'is_ben',
                #     metrics_ary[METRICS_PRED_NDX, benHist_mask],
                #     self.totalTrainingSamples_count,
                #     bins=bins,
                # )
                # writer.add_histogram(
                #     'is_mal',
                #     metrics_ary[METRICS_PRED_NDX, malHist_mask],
                #     self.totalTrainingSamples_count,
                #     bins=bins,
                # )

    def saveModel(self, epoch_ndx):
        file_path = os.path.join(
            'data', 'models', self.cli_args.tb_prefix,
            '{}_{}.{}.state'.format(self.time_str, self.cli_args.comment, self.totalTrainingSamples_count),
        )
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
        state = {
            'model_state': self.model.state_dict(),
            'model_name': type(self.model).__name__,
            'optimizer_state': self.optimizer.state_dict(),  # NOTE: see doTraining; self.optimizer is never created in this file
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
            # 'resumed_from': self.cli_args.resume,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))


if __name__ == '__main__':
    sys.exit(LunaDiagnoseApp().main() or 0)
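
# Example invocation (package name and paths are hypothetical; the relative
# imports above mean this file must be run as a module from the project root):
#
#   python -m luna.diagnose --series-uid <series-uid> \
#       data/models/seg_model.state data/models/cls_model.state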