train_seg.py
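# Training application for the 2D U-Net segmentation model from part 2 of the
# book. It loads LUNA CT slices, trains with a per-sample soft Dice loss, logs
# scalar metrics and sample images to TensorBoard, and checkpoints the model
# after every epoch (keeping a copy of the best-scoring one).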
import argparse
import datetime
import os
import socket
import sys

import numpy as np

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate, xyz2irc
from util.logconf import logging
from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from .model_seg import UNetWrapper

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used by computeBatchLoss and logMetrics to index into metrics_t/metrics_a
METRICS_LABEL_NDX = 0
METRICS_LOSS_NDX = 1
METRICS_MAL_LOSS_NDX = 2
# METRICS_ALL_LOSS_NDX = 3

METRICS_MTP_NDX = 4
METRICS_MFN_NDX = 5
METRICS_MFP_NDX = 6

METRICS_ATP_NDX = 7
METRICS_AFN_NDX = 8
METRICS_AFP_NDX = 9

METRICS_SIZE = 10

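# Naming convention used throughout: a _t suffix marks CPU tensors, _g marks
# GPU (device) tensors, _a marks numpy arrays, and _ndx marks index values.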
class LunaTrainingApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=16,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--augmented',
            help="Augment the training data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-flip',
            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
            action='store_true',
            default=False,
        )
        # parser.add_argument('--augment-offset',
        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
        #     action='store_true',
        #     default=False,
        # )
        # parser.add_argument('--augment-scale',
        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
        #     action='store_true',
        #     default=False,
        # )
        parser.add_argument('--augment-rotate',
            help="Augment the training data by randomly rotating the data around the head-foot axis.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-noise',
            help="Augment the training data by randomly adding noise to the data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch13',
            help="Data prefix to use for TensorBoard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for TensorBoard run.",
            nargs='?',
            default='none',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.totalTrainingSamples_count = 0
        self.trn_writer = None
        self.val_writer = None

        augmentation_dict = {}
        if self.cli_args.augmented or self.cli_args.augment_flip:
            augmentation_dict['flip'] = True
        if self.cli_args.augmented or self.cli_args.augment_rotate:
            augmentation_dict['rotate'] = True
        if self.cli_args.augmented or self.cli_args.augment_noise:
            augmentation_dict['noise'] = 0.025
        self.augmentation_dict = augmentation_dict

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()

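    # Note on the channel count below: in_channels=8 presumably covers the
    # 7 CT slices produced by contextSlices_count=3 (three slices of context
    # on each side of the center slice) plus one extra mask channel; see the
    # ct_t[:-1] handling in logImages.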
    def initModel(self):
        model = UNetWrapper(
            in_channels=8,
            n_classes=1,
            depth=4,
            wf=3,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )
        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model

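    # SGD with unusually high momentum (0.99) pairs with the small learning
    # rate below; the commented-out Adam call is left in as an alternative.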
    def initOptimizer(self):
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
        # return Adam(self.model.parameters())

    def initTrainDl(self):
        train_ds = TrainingLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
            augmentation_dict=self.augmentation_dict,
        )
        train_dl = DataLoader(
            train_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return train_dl

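    # The validation split reuses the same underlying data: val_stride=10
    # holds out every tenth series, and no augmentation_dict is passed.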
    def initValDl(self):
        val_ds = Luna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=True,
            contextSlices_count=3,
        )
        val_dl = DataLoader(
            val_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return val_dl

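    # Writers are created lazily on first use (from logMetrics) rather than
    # in __init__, so a run that dies before producing any metrics does not
    # leave empty TensorBoard files behind.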
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
            self.val_writer = SummaryWriter(log_dir=log_dir + '_val_seg_' + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        # self.logModelMetrics(self.model)

        best_score = 0.0
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
            self.logImages(epoch_ndx, 'trn', train_dl)
            self.logImages(epoch_ndx, 'val', val_dl)
            # self.logModelMetrics(self.model)

            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
            best_score = max(score, best_score)

            self.saveModel('seg', epoch_ndx, score == best_score)

        if self.trn_writer is not None:  # writers stay None until logMetrics runs
            self.trn_writer.close()
            self.val_writer.close()

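    # Training loop: metrics are accumulated per sample into a
    # METRICS_SIZE x len(dataset) tensor kept on the GPU, then moved to the
    # CPU once per epoch for logging.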
    def doTraining(self, epoch_ndx, train_dl):
        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
        self.model.train()

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g)
            loss_var.backward()
            self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += trnMetrics_g.size(1)

        return trnMetrics_g.to('cpu')

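    # Validation mirrors doTraining but wraps the whole pass in no_grad and
    # discards the per-batch loss; only the metrics tensor is kept.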
    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
            self.model.eval()

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')

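    # Two Dice losses are computed per batch: the one on the full nodule
    # label drives backprop, while the malignancy-masked one (and all the
    # tp/fn/fp bookkeeping) runs under no_grad and is recorded for logging.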
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
        input_t, label_t, label_list, ben_t, mal_t, _, _ = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)
        mal_g = mal_t.to(self.device, non_blocking=True)
        ben_g = ben_t.to(self.device, non_blocking=True)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)

        prediction_g = self.model(input_g)

        diceLoss_g = self.diceLoss(label_g, prediction_g)

        with torch.no_grad():
            malLoss_g = self.diceLoss(mal_g, prediction_g * mal_g, p=True)

            predictionBool_g = (prediction_g > 0.5).to(torch.float32)

            metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
            metrics_g[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_g

            malPred_g = predictionBool_g * mal_g
            tp = intersectionSum(mal_g, malPred_g)
            fn = intersectionSum(mal_g, 1 - malPred_g)
            metrics_g[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
            del malPred_g, tp, fn

            tp = intersectionSum(label_g, predictionBool_g)
            fn = intersectionSum(label_g, 1 - predictionBool_g)
            fp = intersectionSum(1 - label_g, predictionBool_g)
            metrics_g[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
            del tp, fn, fp

        return diceLoss_g.mean()

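    # Per-sample soft Dice loss:
    #   loss = 1 - (2*|P∩L| + eps) / (|P| + |L| + eps)
    # where P and L are the (soft) prediction and label masks flattened per
    # sample; epsilon keeps the ratio defined when both masks are empty.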
    # def diceLoss(self, label_g, prediction_g, epsilon=0.01, p=False):
    def diceLoss(self, label_g, prediction_g, epsilon=1, p=False):
        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        diceLabel_g = sum_dim1(label_g)
        dicePrediction_g = sum_dim1(prediction_g)
        diceCorrect_g = sum_dim1(prediction_g * label_g)

        epsilon_g = torch.ones_like(diceCorrect_g) * epsilon
        diceLoss_g = 1 - (2 * diceCorrect_g + epsilon_g) \
            / (dicePrediction_g + diceLabel_g + epsilon_g)

        if p and diceLoss_g.mean() < 0:
            correct_tmp = prediction_g * label_g
            log.debug([])
            log.debug(['diceCorrect_g   ', diceCorrect_g[0].item(), correct_tmp[0].min().item(), correct_tmp[0].mean().item(), correct_tmp[0].max().item(), correct_tmp.shape])
            log.debug(['dicePrediction_g', dicePrediction_g[0].item(), prediction_g[0].min().item(), prediction_g[0].mean().item(), prediction_g[0].max().item(), prediction_g.shape])
            log.debug(['diceLabel_g     ', diceLabel_g[0].item(), label_g[0].min().item(), label_g[0].mean().item(), label_g[0].max().item(), label_g.shape])
            log.debug(['2*diceCorrect_g   ', 2 * diceCorrect_g[0].item()])
            log.debug(['Prediction + Label', dicePrediction_g[0].item() + diceLabel_g[0].item()])
            log.debug(['diceLoss_g        ', diceLoss_g[0].item()])
            assert False

        return diceLoss_g

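    # Writes RGB overlays for the first 12 series at 6 evenly spaced slices:
    # the CT slice forms the gray background, red marks predicted pixels
    # outside the label (false positives), green marks predictions on
    # malignant pixels, and blue marks predictions on benign pixels.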
    def logImages(self, epoch_ndx, mode_str, dl):
        images_iter = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images_iter):
            ct = getCt(series_uid)

            for slice_ndx in range(6):
                ct_ndx = slice_ndx * ct.hu_a.shape[0] // 5
                ct_ndx = min(ct_ndx, ct.hu_a.shape[0] - 1)
                sample_tup = dl.dataset[(series_uid, ct_ndx, False)]

                ct_t, nodule_t, _, ben_t, mal_t, _, _ = sample_tup

                ct_t[:-1,:,:] += 1
                ct_t[:-1,:,:] /= 2

                input_g = ct_t.to(self.device)
                label_g = nodule_t.to(self.device)

                prediction_g = self.model(input_g.unsqueeze(0))[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()
                label_a = nodule_t.numpy()
                ben_a = ben_t.numpy()
                mal_a = mal_t.numpy()

                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()

                image_a = np.zeros((512, 512, 3), dtype=np.float32)
                image_a[:,:,:] = ctSlice_a.reshape((512, 512, 1))
                image_a[:,:,0] += prediction_a[0] * (1 - label_a[0])
                image_a[:,:,1] += prediction_a[0] * mal_a[0]
                image_a[:,:,2] += prediction_a[0] * ben_a[0]
                image_a *= 0.5
                image_a[image_a < 0] = 0
                image_a[image_a > 1] = 1

                writer = getattr(self, mode_str + '_writer')
                writer.add_image(
                    '{}/{}_prediction_{}'.format(
                        mode_str,
                        series_ndx,
                        slice_ndx,
                    ),
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )
                # self.diceLoss(label_g, prediction_g, p=True)

                if epoch_ndx == 1:
                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
                    image_a[:,:,:] = ctSlice_a.reshape((512, 512, 1))
                    image_a[:,:,0] += (1 - label_a[0]) * ct_t[-1].numpy()  # Red
                    image_a[:,:,1] += mal_a[0]  # Green
                    image_a[:,:,2] += ben_a[0]  # Blue
                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )

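    # The label row appears to encode per-sample class flags: values 1 and 3
    # are treated as malignant-positive (malLabel_mask). The returned score
    # rewards a low malignancy Dice loss and a high F1, with small penalties
    # on recall and overall loss acting as tie-breakers.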
    def logMetrics(self,
            epoch_ndx,
            mode_str,
            metrics_t,
    ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        metrics_a = metrics_t.detach().numpy()
        sum_a = metrics_a.sum(axis=1)
        assert np.isfinite(metrics_a).all()

        malLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 1) | (metrics_a[METRICS_LABEL_NDX] == 3)
        # allLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 2) | (metrics_a[METRICS_LABEL_NDX] == 3)

        allLabel_count = sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]
        malLabel_count = sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]

        # allCorrect_count = sum_a[METRICS_ATP_NDX]
        # malCorrect_count = sum_a[METRICS_MTP_NDX]
        #
        # falsePos_count = allLabel_count - allCorrect_count
        # falseNeg_count = malLabel_count - malCorrect_count

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/mal'] = np.nan_to_num(metrics_a[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
        # metrics_dict['loss/all'] = metrics_a[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()

        # metrics_dict['correct/mal'] = sum_a[METRICS_MTP_NDX] / (sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]) * 100
        # metrics_dict['correct/all'] = sum_a[METRICS_ATP_NDX] / (sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) * 100

        metrics_dict['percent_all/tp'] = sum_a[METRICS_ATP_NDX] / (allLabel_count or 1) * 100
        metrics_dict['percent_all/fn'] = sum_a[METRICS_AFN_NDX] / (allLabel_count or 1) * 100
        metrics_dict['percent_all/fp'] = sum_a[METRICS_AFP_NDX] / (allLabel_count or 1) * 100

        metrics_dict['percent_mal/tp'] = sum_a[METRICS_MTP_NDX] / (malLabel_count or 1) * 100
        metrics_dict['percent_mal/fn'] = sum_a[METRICS_MFN_NDX] / (malLabel_count or 1) * 100

        precision = metrics_dict['pr/precision'] = sum_a[METRICS_ATP_NDX] \
            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_a[METRICS_ATP_NDX] \
            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) or 1)

        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
            / ((precision + recall) or 1)

        log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
                  + "{pr/precision:.4f} precision, "
                  + "{pr/recall:.4f} recall, "
                  + "{pr/f1_score:.4f} f1 score"
                  ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
                  + "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
                  # + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
                  ).format(
            epoch_ndx,
            mode_str + '_all',
            # allCorrect_count=allCorrect_count,
            # allLabel_count=allLabel_count,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
                  + "{loss/mal:.4f} loss, "
                  + "{percent_mal/tp:-5.1f}% tp, {percent_mal/fn:-5.1f}% fn"
                  # + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
                  ).format(
            epoch_ndx,
            mode_str + '_mal',
            # malCorrect_count=malCorrect_count,
            # malLabel_count=malLabel_count,
            **metrics_dict,
        ))

        self.initTensorboardWriters()
        writer = getattr(self, mode_str + '_writer')

        prefix_str = 'seg_'

        for key, value in metrics_dict.items():
            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)

        score = 1 \
            - metrics_dict['loss/mal'] \
            + metrics_dict['pr/f1_score'] \
            - metrics_dict['pr/recall'] * 0.01 \
            - metrics_dict['loss/all'] * 0.0001

        return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             writer.add_histogram(
    #                 name.rsplit('.', 1)[-1] + '/' + name,
    #                 param.data.cpu().numpy(),
    #                 # metrics_a[METRICS_PRED_NDX, benHist_mask],
    #                 self.totalTrainingSamples_count,
    #                 # bins=bins,
    #             )
    #
    #             # print name, param.data

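    # Checkpointing unwraps nn.DataParallel (via model.module) so the saved
    # state_dict keys load cleanly on a single device; the best-scoring epoch
    # is written a second time under a '.best.state' name.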
    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        if hasattr(model, 'module'):
            model = model.module

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            file_path = os.path.join(
                'data-unversioned',
                'part2',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(
                    type_str,
                    self.time_str,
                    self.cli_args.comment,
                    'best',
                )
            )
            torch.save(state, file_path)
            log.debug("Saved model params to {}".format(file_path))

if __name__ == '__main__':
    LunaTrainingApp().main()
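
# Example invocation, assuming this module lives in the book's package layout
# (e.g. a p2ch13 package, matching the default --tb-prefix):
#
#   python -m p2ch13.train_seg --epochs 10 --augmented --batch-size 32 my-run
#
# The positional comment argument ('my-run') becomes part of the TensorBoard
# run name and the checkpoint filename.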