training.py

import argparse
import datetime
import os
import socket
import sys

import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from util.logconf import logging
from util.util import xyz2irc
from .model import UNetWrapper

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used by computeBatchLoss and logPerformanceMetrics to index into metrics_tensor/metrics_ary
METRICS_LABEL_NDX = 0
METRICS_LOSS_NDX = 1
METRICS_MAL_LOSS_NDX = 2
METRICS_BEN_LOSS_NDX = 3
METRICS_MTP_NDX = 4
METRICS_MFN_NDX = 5
METRICS_MFP_NDX = 6
METRICS_BTP_NDX = 7
METRICS_BFN_NDX = 8
# METRICS_BFP_NDX = 9
METRICS_SIZE = 9
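
# Layout note: the metrics tensor built in doTraining/doTesting has shape
# (METRICS_SIZE, num_samples); each batch fills its own slice of columns, so
# every sample keeps one value per metric row above.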

class LunaTrainingApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=24,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--augmented',
            help="Augment the training data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-flip',
            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
            action='store_true',
            default=False,
        )
        # parser.add_argument('--augment-offset',
        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
        #     action='store_true',
        #     default=False,
        # )
        # parser.add_argument('--augment-scale',
        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
        #     action='store_true',
        #     default=False,
        # )
        parser.add_argument('--augment-rotate',
            help="Augment the training data by randomly rotating the data around the head-foot axis.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-noise',
            help="Augment the training data by randomly adding noise to the data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch12',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )
        self.cli_args = parser.parse_args(sys_argv)

        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.trn_writer = None
        self.tst_writer = None

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        # # TODO: remove this if block before print
        # # This is due to an odd setup that the author is using to test the code; please ignore for now
        # if socket.gethostname() == 'c2':
        #     self.device = torch.device("cuda:1")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()
        self.totalTrainingSamples_count = 0

        augmentation_dict = {}
        if self.cli_args.augmented or self.cli_args.augment_flip:
            augmentation_dict['flip'] = True
        if self.cli_args.augmented or self.cli_args.augment_rotate:
            augmentation_dict['rotate'] = True
        if self.cli_args.augmented or self.cli_args.augment_noise:
            augmentation_dict['noise'] = 25.0
        self.augmentation_dict = augmentation_dict
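
    # Note on initModel below (inferred, not stated in this file):
    # in_channels=7 matches the dataset's contextSlices_count=3, i.e.
    # 2*3 + 1 slices per sample, and n_classes=1 yields a single
    # per-pixel nodule mask.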

    def initModel(self):
        # model = UNetWrapper(in_channels=8, n_classes=2, depth=3, wf=6, padding=True, batch_norm=True, up_mode='upconv')
        model = UNetWrapper(in_channels=7, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model

    def initOptimizer(self):
        return SGD(self.model.parameters(), lr=0.01, momentum=0.99)
        # return Adam(self.model.parameters())
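
    # The batch size in the data loaders below is multiplied by the GPU count
    # because nn.DataParallel splits each batch evenly across the available
    # devices.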

    def initTrainDl(self):
        train_ds = TrainingLuna2dSegmentationDataset(
            test_stride=10,
            contextSlices_count=3,
            augmentation_dict=self.augmentation_dict,
        )
        train_dl = DataLoader(
            train_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return train_dl

    def initTestDl(self):
        test_ds = Luna2dSegmentationDataset(
            test_stride=10,
            contextSlices_count=3,
        )
        test_dl = DataLoader(
            test_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return test_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_seg_' + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        test_dl = self.initTestDl()
        self.initTensorboardWriters()
        # self.logModelMetrics(self.model)

        best_score = 0.0
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(test_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
            self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
            self.logImages(epoch_ndx, train_dl, test_dl)
            # self.logModelMetrics(self.model)

            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
            score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
            best_score = max(score, best_score)
            # This app only trains the segmentation model, so the type is always 'seg'.
            self.saveModel('seg', epoch_ndx, score == best_score)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.tst_writer.close()
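
    # Per-sample metrics are accumulated on the GPU for the whole epoch and
    # moved to the CPU once at the end, avoiding a transfer per batch.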

    def doTraining(self, epoch_ndx, train_dl):
        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
        self.model.train()

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_devtensor)
            loss_var.backward()
            self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)

        return trainingMetrics_devtensor.to('cpu')

    def doTesting(self, epoch_ndx, test_dl):
        with torch.no_grad():
            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
            self.model.eval()

            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)

        return testingMetrics_devtensor.to('cpu')
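
    # computeBatchLoss returns the differentiable dice loss for backprop and,
    # under no_grad, records per-sample statistics into columns
    # [start_ndx:end_ndx) of the shared metrics tensor.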

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
        input_tensor, label_tensor, label_list, ben_tensor, mal_tensor, _series_list, _start_list = batch_tup

        input_devtensor = input_tensor.to(self.device, non_blocking=True)
        label_devtensor = label_tensor.to(self.device, non_blocking=True)
        mal_devtensor = mal_tensor.to(self.device, non_blocking=True)
        ben_devtensor = ben_tensor.to(self.device, non_blocking=True)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        # Per-sample overlap of two masks: flatten everything after the batch
        # dimension and sum.
        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)

        prediction_devtensor = self.model(input_devtensor)
        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)

        with torch.no_grad():
            predictionBool_devtensor = (prediction_devtensor > 0.5).to(torch.float32)

            metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
            metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_devtensor

            malPred_devtensor = predictionBool_devtensor * (1 - ben_devtensor)
            tp = intersectionSum(    mal_devtensor,     malPred_devtensor)
            fn = intersectionSum(    mal_devtensor, 1 - malPred_devtensor)
            fp = intersectionSum(1 - mal_devtensor,     malPred_devtensor)
            ls = self.diceLoss(mal_devtensor, malPred_devtensor)

            metrics_devtensor[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
            metrics_devtensor[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
            metrics_devtensor[METRICS_MFP_NDX, start_ndx:end_ndx] = fp
            metrics_devtensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = ls
            del malPred_devtensor, tp, fn, fp, ls

            benPred_devtensor = predictionBool_devtensor * (1 - mal_devtensor)
            tp = intersectionSum(    ben_devtensor,     benPred_devtensor)
            fn = intersectionSum(    ben_devtensor, 1 - benPred_devtensor)
            # fp = intersectionSum(1 - ben_devtensor,     benPred_devtensor)
            ls = self.diceLoss(ben_devtensor, benPred_devtensor)

            metrics_devtensor[METRICS_BTP_NDX, start_ndx:end_ndx] = tp
            metrics_devtensor[METRICS_BFN_NDX, start_ndx:end_ndx] = fn
            # metrics_devtensor[METRICS_BFP_NDX, start_ndx:end_ndx] = fp
            metrics_devtensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = ls
            del benPred_devtensor, tp, fn, ls

        return diceLoss_devtensor.mean()
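
    # Per-sample dice loss: 1 - (2*|P∩L| + eps) / (|P| + |L| + eps).
    # The epsilon term keeps the ratio defined when both masks are empty
    # (giving a loss of 0 rather than 0/0).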

    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01):
        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        diceLabel_devtensor = sum_dim1(label_devtensor)
        dicePrediction_devtensor = sum_dim1(prediction_devtensor)
        diceCorrect_devtensor = sum_dim1(prediction_devtensor * label_devtensor)

        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) \
            / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)

        return diceLoss_devtensor
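
    # Reading of the channel math in logImages below (not documented in the
    # original): grayscale CT fills all three channels, predictions outside
    # the label tint red (false positives), predicted-malignant overlap tints
    # green, and predicted-benign overlap tints blue.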

    def logImages(self, epoch_ndx, train_dl, test_dl):
        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
                ct = getCt(series_uid)
                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)

                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
                # input_tensor = sample_tup[0].unsqueeze(0)
                # label_tensor = sample_tup[1].unsqueeze(0)
                input_tensor, label_tensor, ben_tensor, mal_tensor = sample_tup[:4]

                # Shift CT values from roughly [-1000, 1000] HU into [0, 1).
                input_tensor += 1000
                input_tensor /= 2001

                input_devtensor = input_tensor.to(self.device)
                # label_devtensor = label_tensor.to(self.device)

                prediction_devtensor = self.model(input_devtensor.unsqueeze(0))[0]
                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
                label_ary = label_tensor.numpy()
                ben_ary = ben_tensor.numpy()
                mal_ary = mal_tensor.numpy()
                # log.debug([prediction_ary[0].shape, label_ary.shape, mal_ary.shape])

                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                image_ary[:,:,:] = input_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1)) * 0.5
                image_ary[:,:,0] += prediction_ary[0] * (1 - label_ary[0]) * 0.5
                image_ary[:,:,1] += prediction_ary[0] * mal_ary * 0.5
                image_ary[:,:,2] += prediction_ary[0] * ben_ary * 0.5
                # image_ary[:,:,2] += prediction_ary[0,1] * 0.25
                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
                # image_ary = (image_ary * 255).astype(np.uint8)

                writer = getattr(self, mode_str + '_writer')
                try:
                    image_ary[image_ary < 0] = 0
                    image_ary[image_ary > 1] = 1
                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
                except:
                    log.debug([image_ary.shape, image_ary.dtype])
                    raise

                if epoch_ndx == 1:
                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                    image_ary[:,:,:] = input_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1)) * 0.5
                    image_ary[:,:,1] += mal_ary * 0.5
                    image_ary[:,:,2] += ben_ary * 0.5
                    # image_ary[:,:,2] += label_ary[0,1] * 0.25
                    # image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool) | label_ary[0,1].astype(np.bool))) * 0.25
                    # image_ary = (image_ary * 255).astype(np.uint8)

                    writer = getattr(self, mode_str + '_writer')
                    image_ary[image_ary < 0] = 0
                    image_ary[image_ary > 1] = 1
                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
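
    # Assumption about the label encoding (it lives in dsets.py, not here):
    # based on the masks below, values 1 and 3 mark samples with a malignant
    # annotation, and 2 and 3 mark samples with a benign annotation.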

    def logPerformanceMetrics(self,
            epoch_ndx,
            mode_str,
            metrics_tensor,
    ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))
        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
        metrics_ary = metrics_tensor.cpu().detach().numpy()
        sum_ary = metrics_ary.sum(axis=1)
        assert np.isfinite(metrics_ary).all()

        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
        # malLabel_mask = ~benLabel_mask
        # malPred_mask = ~benPred_mask

        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]

        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
        # falsePos_count = benLabel_count - benCorrect_count
        # falseNeg_count = malLabel_count - malCorrect_count

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()
        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100

        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100

        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

        log.info(("E{} {:8} "
            + "{loss/all:.4f} loss, "
            # + "{loss/flg:.4f} flagged loss, "
            # + "{flagged/all:-5.1f}% pixels flagged, "
            # + "{flagged/slices:-5.1f}% slices flagged, "
            + "{pr/precision:.4f} precision, "
            + "{pr/recall:.4f} recall, "
            + "{pr/f1_score:.4f} f1 score"
        ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/mal:.4f} loss, "
            + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_mal',
            malCorrect_count=malCorrect_count,
            malLabel_count=malLabel_count,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/ben:.4f} loss, "
            + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_ben',
            benCorrect_count=benCorrect_count,
            benLabel_count=benLabel_count,
            **metrics_dict,
        ))

        writer = getattr(self, mode_str + '_writer')
        prefix_str = 'seg_'
        for key, value in metrics_dict.items():
            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)

        # Composite model-selection score: the F1 term dominates, with the
        # malignancy and overall dice losses acting as small tie-breakers.
        score = 1 \
            + metrics_dict['pr/f1_score'] \
            - metrics_dict['loss/mal'] * 0.01 \
            - metrics_dict['loss/all'] * 0.0001

        return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             writer.add_histogram(
    #                 name.rsplit('.', 1)[-1] + '/' + name,
    #                 param.data.cpu().numpy(),
    #                 # metrics_ary[METRICS_PRED_NDX, benHist_mask],
    #                 self.totalTrainingSamples_count,
    #                 # bins=bins,
    #             )
    #
    #             # print name, param.data

    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, self.totalTrainingSamples_count),
        )
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        # Unwrap nn.DataParallel so the saved state_dict has plain key names.
        if hasattr(model, 'module'):
            model = model.module

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
            # 'resumed_from': self.cli_args.resume,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            file_path = os.path.join(
                'data-unversioned',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, 'best'),
            )
            torch.save(state, file_path)
            log.debug("Saved model params to {}".format(file_path))

if __name__ == '__main__':
    sys.exit(LunaTrainingApp().main() or 0)
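
# Example invocation (a sketch; the package name is an assumption, taken from
# the default --tb-prefix, since this file uses relative imports):
#   python -m p2ch12.training --epochs 10 --augmented --batch-size 16 run1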