import argparse
import datetime
import os
import socket
import sys

import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import (
    TrainingLuna2dSegmentationDataset,
    TestingLuna2dSegmentationDataset,
    LunaClassificationDataset,
    getCt,
)
from util.logconf import logging
from util.util import xyz2irc
from .model import UNetWrapper, LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used by computeClassificationLoss, computeSegmentationLoss, and
# logPerformanceMetrics to index into metrics_tensor/metrics_ary.
# METRICS_LABEL_NDX=0
# METRICS_PRED_NDX=1
# METRICS_LOSS_NDX=2
# METRICS_MAL_LOSS_NDX=3
# METRICS_BEN_LOSS_NDX=4
# METRICS_LUNG_LOSS_NDX=5
# METRICS_MASKLOSS_NDX=2
# METRICS_MALLOSS_NDX=3
METRICS_LOSS_NDX = 0
METRICS_LABEL_NDX = 1
METRICS_PRED_NDX = 2
METRICS_MTP_NDX = 3
METRICS_MFN_NDX = 4
METRICS_MFP_NDX = 5
METRICS_BTP_NDX = 6
METRICS_BFN_NDX = 7
METRICS_BFP_NDX = 8
METRICS_MAL_LOSS_NDX = 9
METRICS_BEN_LOSS_NDX = 10
# METRICS_MFOUND_NDX = 2
# METRICS_MOK_NDX = 2
# METRICS_FLG_LOSS_NDX = 10
METRICS_SIZE = 11
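
# Each row of the metrics tensor holds one value per sample: the tensor has
# shape (METRICS_SIZE, num_samples), so e.g.
# metrics_tensor[METRICS_LOSS_NDX, sample_ndx] is that sample's loss.
# logPerformanceMetrics later reduces each row into per-epoch numbers.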


class LunaTrainingApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        # parser.add_argument('--resume',
        #     default=None,
        #     help="File to resume training from.",
        # )
        parser.add_argument('--segmentation',
            help="Train the 2D segmentation model instead of the nodule classifier.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--balanced',
            help="Balance the training data to half benign, half malignant.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--adaptive',
            help="Balance the training data to start half benign, half malignant, and end at a 100:1 ratio.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--scaled',
            help="Scale the CT chunks to square voxels.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--multiscaled',
            help="Scale the CT chunks to square voxels.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augmented',
            help="Augment the training data (implies --scaled).",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch10',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.tst_writer = None

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        # TODO: remove this if block before print
        # This is due to an odd setup that the author is using to test the code; please ignore for now
        if socket.gethostname() == 'c2':
            self.device = torch.device("cuda:1")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()
        self.totalTrainingSamples_count = 0
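
    # Builds either the U-Net segmentation wrapper or the nodule classifier,
    # wraps it in nn.DataParallel when more than one GPU is visible, and moves
    # it to the selected device.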

    def initModel(self):
        if self.cli_args.segmentation:
            model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
        else:
            model = LunaModel()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                # TODO: remove this if block before print
                # This is due to an odd setup that the author is using to test the code; please ignore for now
                if socket.gethostname() == 'c2':
                    model = nn.DataParallel(model, device_ids=[1, 0])
                else:
                    model = nn.DataParallel(model)
            model = model.to(self.device)
        return model

    def initOptimizer(self):
        return SGD(self.model.parameters(), lr=0.01, momentum=0.99)
        # return Adam(self.model.parameters())
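
    # Note: the loaders below multiply --batch-size by the visible GPU count,
    # since nn.DataParallel splits each loader batch across GPUs; this keeps
    # the per-GPU batch size at --batch-size.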

    def initTrainDl(self):
        if self.cli_args.segmentation:
            train_ds = TrainingLuna2dSegmentationDataset(
                test_stride=10,
                contextSlices_count=3,
            )
        else:
            train_ds = LunaClassificationDataset(
                test_stride=10,
                isTestSet_bool=False,
                # series_uid=None,
                # sortby_str='random',
                ratio_int=int(self.cli_args.balanced),
                # scaled_bool=False,
                # multiscaled_bool=False,
                # augmented_bool=False,
                # noduleInfo_list=None,
            )

        train_dl = DataLoader(
            train_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return train_dl

    def initTestDl(self):
        if self.cli_args.segmentation:
            test_ds = TestingLuna2dSegmentationDataset(
                test_stride=10,
                contextSlices_count=3,
            )
        else:
            test_ds = LunaClassificationDataset(
                test_stride=10,
                isTestSet_bool=True,
                # series_uid=None,
                # sortby_str='random',
                # ratio_int=int(self.cli_args.balanced),
                # scaled_bool=False,
                # multiscaled_bool=False,
                # augmented_bool=False,
                # noduleInfo_list=None,
            )

        test_dl = DataLoader(
            test_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return test_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
            type_str = 'seg_' if self.cli_args.segmentation else 'cls_'
            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_' + type_str + self.cli_args.comment)
            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_' + type_str + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        test_dl = self.initTestDl()

        self.initTensorboardWriters()
        self.logModelMetrics(self.model)

        best_score = 0.0
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(test_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
            if epoch_ndx > 0:
                self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
                self.logModelMetrics(self.model)

                if self.cli_args.segmentation:
                    self.logImages(epoch_ndx, train_dl, test_dl)

                testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
                score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
                best_score = max(score, best_score)
                self.saveModel('seg' if self.cli_args.segmentation else 'cls', epoch_ndx, score == best_score)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.tst_writer.close()
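
    # One training epoch: iterate the loader, backpropagate the per-batch loss,
    # and fill a fresh (METRICS_SIZE, num_samples) tensor with per-sample
    # metrics. start_ndx=num_workers tells enumerateWithEstimate to skip the
    # data-loader warm-up batches when estimating the epoch's duration.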

    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
        # train_dl.dataset.shuffleSamples()
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            if self.cli_args.segmentation:
                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
            else:
                loss_var = self.computeClassificationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)

            if loss_var is not None:
                loss_var.backward()
                self.optimizer.step()
                del loss_var

        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
        return trainingMetrics_tensor

    def doTesting(self, epoch_ndx, test_dl):
        with torch.no_grad():
            self.model.eval()
            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                if self.cli_args.segmentation:
                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
                else:
                    self.computeClassificationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
        return testingMetrics_tensor
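
    # One classification batch: MSELoss(reduction='none') yields one loss value
    # per sample, letting us record per-sample losses and confusion-matrix
    # entries before returning the batch mean for backprop.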

    def computeClassificationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _center_list = batch_tup

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        with torch.no_grad():
            # log.debug([metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx].shape, label_tensor.shape])
            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor[:,0]
            metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')[:,0]
            # metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')

            prediction_tensor = prediction_devtensor.to('cpu', non_blocking=True)
            loss_tensor = loss_devtensor.to('cpu', non_blocking=True)[:,0]

            malLabel_tensor = (label_tensor > 0.5)[:,0]
            benLabel_tensor = ~malLabel_tensor
            malPred_tensor = (prediction_tensor > 0.5)[:,0]
            benPred_tensor = ~malPred_tensor

            # Per-sample confusion-matrix entries: 1.0 where the condition
            # holds, 0.0 elsewhere; logPerformanceMetrics sums them later.
            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = malLabel_tensor & malPred_tensor
            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = malLabel_tensor & benPred_tensor
            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = benLabel_tensor & malPred_tensor
            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = benLabel_tensor & benPred_tensor
            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = benLabel_tensor & malPred_tensor
            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = malLabel_tensor & benPred_tensor

            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_tensor
            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * benLabel_tensor.type(torch.float32)
            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * malLabel_tensor.type(torch.float32)

        # TODO: replace with torch.autograd.detect_anomaly
        # assert np.isfinite(metrics_tensor).all()
        return loss_devtensor.mean()
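
    # One segmentation batch: channel 0 of the label/prediction tensors is the
    # malignant mask, channel 1 the benign mask. Per-pixel confusion counts and
    # per-class Dice losses are recorded; the sum of the mean per-class Dice
    # losses is returned for backprop.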

    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _start_list = batch_tup
        # if label_tensor.max() < 0.5:
        #     return None

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        # assert prediction_devtensor.is_contiguous()

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)

        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
        malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
        benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])

        with torch.no_grad():
            bPred_tensor = prediction_devtensor.to('cpu', non_blocking=True)
            diceLoss_tensor = diceLoss_devtensor.to('cpu', non_blocking=True)
            malLoss_tensor = malLoss_devtensor.to('cpu', non_blocking=True)
            benLoss_tensor = benLoss_devtensor.to('cpu', non_blocking=True)
            # flgLoss_devtensor = self.diceLoss(label_devtensor[:,0], label_devtensor[:,0] * prediction_devtensor[:,1])
            # flgLoss_tensor = flgLoss_devtensor.to('cpu', non_blocking=True)#.unsqueeze(1)

            # Label encoding: 0 = neither, 1 = malignant only, 2 = benign only, 3 = both.
            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0]) + max2(label_tensor[:,1]) * 2
            # metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:, 0] * bPred_tensor[:, 1].to(torch.float32)) > 0.5)
            # metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], bPred_tensor[:,1])

            bPred_tensor = bPred_tensor > 0.5
            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], bPred_tensor[:,0])
            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], ~bPred_tensor[:,0])
            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0], bPred_tensor[:,0])
            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,1], bPred_tensor[:,1])
            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,1], ~bPred_tensor[:,1])
            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1], bPred_tensor[:,1])

            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor
            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor
            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor
            # metrics_tensor[METRICS_FLG_LOSS_NDX, start_ndx:end_ndx] = flgLoss_tensor

            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor

        # TODO: replace with torch.autograd.detect_anomaly
        # assert np.isfinite(metrics_tensor).all()
        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
        return malLoss_devtensor.mean() + benLoss_devtensor.mean()
        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
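
    # Per-sample soft Dice loss over all pixels of each sample:
    #   loss = 1 - (2 * sum(pred * label) + eps) / (sum(pred) + sum(label) + eps)
    # The epsilon term keeps the ratio finite when a sample has no positive
    # pixels in either the label or the prediction.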

    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
        # sum2 = lambda t: t.sum([1,2,3,4])
        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]

        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
        dicePrediction_devtensor = sum2(prediction_devtensor)
        diceLabel_devtensor = sum2(label_devtensor)
        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)

        if not torch.isfinite(diceLoss_devtensor).all():
            log.debug('')
            log.debug('diceLoss_devtensor')
            log.debug(diceLoss_devtensor.to('cpu'))
            log.debug('diceCorrect_devtensor')
            log.debug(diceCorrect_devtensor.to('cpu'))
            log.debug('dicePrediction_devtensor')
            log.debug(dicePrediction_devtensor.to('cpu'))
            log.debug('diceLabel_devtensor')
            log.debug(diceLabel_devtensor.to('cpu'))

        return diceLoss_devtensor
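
    # Writes overlay images for up to 12 CT series per split to TensorBoard.
    # The prediction overlay is refreshed every epoch; the ground-truth label
    # overlay is written only on epoch 1, since the labels never change.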

    def logImages(self, epoch_ndx, train_dl, test_dl):
        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
                ct = getCt(series_uid)
                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)

                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
                input_tensor = sample_tup[0].unsqueeze(0)
                label_tensor = sample_tup[1].unsqueeze(0)

                input_devtensor = input_tensor.to(self.device)
                label_devtensor = label_tensor.to(self.device)

                prediction_devtensor = self.model(input_devtensor)
                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()

                # Gray CT slice with the malignant prediction added to the red
                # channel and the benign prediction added to the green channel.
                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
                # log.debug([image_ary.__array_interface__['typestr']])
                # image_ary = (image_ary * 255).astype(np.uint8)
                # log.debug([image_ary.__array_interface__['typestr']])

                writer = getattr(self, mode_str + '_writer')
                try:
                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
                except:
                    log.debug([image_ary.shape, image_ary.dtype])
                    raise

                if epoch_ndx == 1:
                    label_ary = label_tensor.numpy()

                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
                    image_ary[:,:,0] += label_ary[0,0] * 0.5
                    image_ary[:,:,1] += label_ary[0,1] * 0.25
                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(bool) | label_ary[0,1].astype(bool))) * 0.25
                    # log.debug([image_ary.__array_interface__['typestr']])
                    image_ary = (image_ary * 255).astype(np.uint8)
                    # log.debug([image_ary.__array_interface__['typestr']])

                    writer = getattr(self, mode_str + '_writer')
                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
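
    # Reduces the per-sample metrics into per-epoch numbers, logs them, and
    # writes TensorBoard scalars (plus a PR curve and prediction histograms in
    # classification mode). Returns a composite score based on F1 minus small
    # loss penalties, which main() uses to track the best checkpoint.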

    def logPerformanceMetrics(self,
            epoch_ndx,
            mode_str,
            metrics_tensor,
            # trainingMetrics_tensor,
            # testingMetrics_tensor,
            classificationThreshold_float=0.5,
    ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        score = 0.0
        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
        metrics_ary = metrics_tensor.cpu().detach().numpy()
        sum_ary = metrics_ary.sum(axis=1)
        assert np.isfinite(metrics_ary).all()

        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        if self.cli_args.segmentation:
            benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        else:
            benLabel_mask = ~malLabel_mask
        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
        # malLabel_mask = ~benLabel_mask
        # malPred_mask = ~benPred_mask

        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]

        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
        # falsePos_count = benLabel_count - benCorrect_count
        # falseNeg_count = malLabel_count - malCorrect_count

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
        # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()

        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100

        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100

        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

        log.info(("E{} {:8} "
            + "{loss/all:.4f} loss, "
            # + "{loss/flg:.4f} flagged loss, "
            # + "{flagged/all:-5.1f}% pixels flagged, "
            # + "{flagged/slices:-5.1f}% slices flagged, "
            + "{pr/precision:.4f} precision, "
            + "{pr/recall:.4f} recall, "
            + "{pr/f1_score:.4f} f1 score"
        ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/mal:.4f} loss, "
            + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_mal',
            malCorrect_count=malCorrect_count,
            malLabel_count=malLabel_count,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/ben:.4f} loss, "
            + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_ben',
            benCorrect_count=benCorrect_count,
            benLabel_count=benLabel_count,
            **metrics_dict,
        ))

        writer = getattr(self, mode_str + '_writer')
        prefix_str = 'seg_' if self.cli_args.segmentation else ''
        for key, value in metrics_dict.items():
            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)

        if not self.cli_args.segmentation:
            writer.add_pr_curve(
                'pr',
                metrics_ary[METRICS_LABEL_NDX],
                metrics_ary[METRICS_PRED_NDX],
                self.totalTrainingSamples_count,
            )

            benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
            malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
            bins = [x/50.0 for x in range(51)]
            writer.add_histogram(
                'is_ben',
                metrics_ary[METRICS_PRED_NDX, benHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )
            writer.add_histogram(
                'is_mal',
                metrics_ary[METRICS_PRED_NDX, malHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

        score = 1 \
            + metrics_dict['pr/f1_score'] \
            - metrics_dict['loss/mal'] * 0.01 \
            - metrics_dict['loss/all'] * 0.0001

        return score
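
    # Writes a histogram of every trainable parameter tensor to TensorBoard,
    # using symmetric bins scaled to each tensor's largest absolute value.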

    def logModelMetrics(self, model):
        writer = getattr(self, 'trn_writer')
        model = getattr(model, 'module', model)

        for name, param in model.named_parameters():
            if param.requires_grad:
                min_data = float(param.data.min())
                max_data = float(param.data.max())
                max_extent = max(abs(min_data), abs(max_data))

                bins = [x/50*max_extent for x in range(-50, 51)]

                writer.add_histogram(
                    name.rsplit('.', 1)[-1] + '/' + name,
                    param.data.cpu().numpy(),
                    # metrics_ary[METRICS_PRED_NDX, benHist_mask],
                    self.totalTrainingSamples_count,
                    bins=bins,
                )
                # print name, param.data
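
    # Saves model and optimizer state dicts (unwrapping nn.DataParallel via
    # .module so the checkpoint keys stay portable); the best-scoring epoch so
    # far is also written to a second file ending in '.best.state'.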

    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, self.totalTrainingSamples_count),
        )
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        if hasattr(model, 'module'):
            model = model.module

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
            # 'resumed_from': self.cli_args.resume,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            file_path = os.path.join(
                'data-unversioned',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, 'best'),
            )
            torch.save(state, file_path)
            log.debug("Saved model params to {}".format(file_path))


if __name__ == '__main__':
    sys.exit(LunaTrainingApp().main() or 0)
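
# Example invocation (a sketch; the package name is an assumption based on the
# default --tb-prefix of 'p2ch10'; adjust it to wherever this module lives,
# since the relative imports above require running it as part of a package):
#
#   python -m p2ch10.training --segmentation --epochs 10 --batch-size 16 my-run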