training.py

import argparse
import datetime
import os
import socket
import sys

import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import TrainingLuna2dSegmentationDataset, TestingLuna2dSegmentationDataset, LunaClassificationDataset, getCt
from util.logconf import logging
from util.util import xyz2irc
from .model import UNetWrapper, LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeClassificationLoss and logMetrics to index into metrics_tensor/metrics_ary
# METRICS_LABEL_NDX=0
# METRICS_PRED_NDX=1
# METRICS_LOSS_NDX=2
# METRICS_MAL_LOSS_NDX=3
# METRICS_BEN_LOSS_NDX=4
# METRICS_LUNG_LOSS_NDX=5
# METRICS_MASKLOSS_NDX=2
# METRICS_MALLOSS_NDX=3
METRICS_LOSS_NDX = 0
METRICS_LABEL_NDX = 1
METRICS_PRED_NDX = 2
METRICS_MTP_NDX = 3
METRICS_MFN_NDX = 4
METRICS_MFP_NDX = 5
METRICS_BTP_NDX = 6
METRICS_BFN_NDX = 7
METRICS_BFP_NDX = 8
METRICS_MAL_LOSS_NDX = 9
METRICS_BEN_LOSS_NDX = 10
# METRICS_MFOUND_NDX = 2
# METRICS_MOK_NDX = 2
# METRICS_FLG_LOSS_NDX = 10
METRICS_SIZE = 11
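
# metrics_tensor layout: one row per metric channel (METRICS_SIZE rows), one
# column per sample in the epoch. Rows 0-2 hold per-sample loss, label, and
# prediction; rows 3-8 hold per-sample true-positive / false-negative /
# false-positive tallies for the malignant (M*) and benign (B*) channels;
# rows 9-10 hold the per-class losses. logPerformanceMetrics later sums across
# the sample axis to turn these into epoch-level counts and means.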

class LunaTrainingApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=4,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        # parser.add_argument('--resume',
        #     default=None,
        #     help="File to resume training from.",
        # )
        parser.add_argument('--segmentation',
            help="Train the segmentation model instead of the classification model.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--balanced',
            help="Balance the training data to half benign, half malignant.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--adaptive',
            help="Balance the training data to start half benign, half malignant, and end at a 100:1 ratio.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--scaled',
            help="Scale the CT chunks to square voxels.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--multiscaled',
            help="Scale the CT chunks to square voxels at multiple sizes.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augmented',
            help="Augment the training data (implies --scaled).",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch10',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.tst_writer = None

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if socket.gethostname() == 'c2':
            self.device = torch.device("cuda:1")  # TODO: remove me before print

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()

        self.totalTrainingSamples_count = 0
    def initModel(self):
        if self.cli_args.segmentation:
            model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
        else:
            model = LunaModel()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                if socket.gethostname() == 'c2':
                    model = nn.DataParallel(model, device_ids=[1, 0])  # TODO: remove me before print
                else:
                    model = nn.DataParallel(model)
            model = model.to(self.device)
        return model
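
    # Note: nn.DataParallel splits each incoming batch across the visible GPUs
    # and gathers the outputs back on the primary device, so the model code
    # itself stays single-device. The c2-specific device_ids ordering above is
    # a development-machine workaround, as its TODO says.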
    def initOptimizer(self):
        # self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.99)
        return Adam(self.model.parameters())
    def initTrainDl(self):
        if self.cli_args.segmentation:
            train_ds = TrainingLuna2dSegmentationDataset(
                test_stride=10,
                contextSlices_count=3,
            )
        else:
            train_ds = LunaClassificationDataset(
                test_stride=10,
                isTestSet_bool=False,
                # series_uid=None,
                # sortby_str='random',
                ratio_int=int(self.cli_args.balanced),
                # scaled_bool=False,
                # multiscaled_bool=False,
                # augmented_bool=False,
                # noduleInfo_list=None,
            )

        train_dl = DataLoader(
            train_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return train_dl
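
    # The DataLoader batch size is scaled by the GPU count so that, after
    # nn.DataParallel splits the batch, each GPU still sees --batch-size
    # samples (e.g. --batch-size 4 on two GPUs loads batches of 8).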
    def initTestDl(self):
        if self.cli_args.segmentation:
            test_ds = TestingLuna2dSegmentationDataset(
                test_stride=10,
                contextSlices_count=3,
            )
        else:
            test_ds = LunaClassificationDataset(
                test_stride=10,
                isTestSet_bool=True,
                # series_uid=None,
                # sortby_str='random',
                # ratio_int=int(self.cli_args.balanced),
                # scaled_bool=False,
                # multiscaled_bool=False,
                # augmented_bool=False,
                # noduleInfo_list=None,
            )

        test_dl = DataLoader(
            test_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return test_dl
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
            self.trn_writer = SummaryWriter(log_dir=log_dir + '_segtrn_' + self.cli_args.comment)
            self.tst_writer = SummaryWriter(log_dir=log_dir + '_segtst_' + self.cli_args.comment)
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        test_dl = self.initTestDl()

        self.initTensorboardWriters()
        self.logModelMetrics(self.model)

        best_score = 0.0
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(test_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
            self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
            self.logModelMetrics(self.model)

            if self.cli_args.segmentation:
                self.logImages(epoch_ndx, train_dl, test_dl)

            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
            score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
            best_score = max(score, best_score)

            self.saveModel('seg' if self.cli_args.segmentation else 'cls', epoch_ndx, score == best_score)

        if self.trn_writer is not None:
            self.trn_writer.close()
            self.tst_writer.close()
    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
        # train_dl.dataset.shuffleSamples()
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            if self.cli_args.segmentation:
                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
            else:
                loss_var = self.computeClassificationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)

            if loss_var is not None:
                loss_var.backward()
                self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
        return trainingMetrics_tensor
    def doTesting(self, epoch_ndx, test_dl):
        with torch.no_grad():
            self.model.eval()
            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                if self.cli_args.segmentation:
                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
                else:
                    self.computeClassificationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
        return testingMetrics_tensor
    def computeClassificationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _center_list = batch_tup

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        with torch.no_grad():
            # log.debug([metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx].shape, label_tensor.shape])
            prediction_tensor = prediction_devtensor.to('cpu', non_blocking=True)
            loss_tensor = loss_devtensor.to('cpu', non_blocking=True)[:,0]

            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor[:,0]
            metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_tensor[:,0]
            # metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')

            # Per-sample 0/1 confusion flags; both masks are 1D over the batch
            # so each sample lands in its own metrics column.
            malLabel_tensor = (label_tensor > 0.5)[:,0]
            benLabel_tensor = ~malLabel_tensor
            malPred_tensor = (prediction_tensor > 0.5)[:,0]
            benPred_tensor = ~malPred_tensor

            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = malLabel_tensor & malPred_tensor
            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = malLabel_tensor & benPred_tensor
            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = benLabel_tensor & malPred_tensor
            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = benLabel_tensor & benPred_tensor
            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = benLabel_tensor & malPred_tensor
            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = malLabel_tensor & benPred_tensor

            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_tensor
            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * benLabel_tensor.type(torch.float32)
            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * malLabel_tensor.type(torch.float32)

            # TODO: replace with torch.autograd.detect_anomaly
            # assert np.isfinite(metrics_tensor).all()

        return loss_devtensor.mean()
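
    # With reduction='none', MSELoss returns one loss value per element; the
    # [:,0] slice above records the per-sample loss for logging, while the
    # .mean() in the return value collapses the same tensor to the scalar that
    # doTraining backpropagates.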
    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
        input_tensor, label_tensor, _series_list, _start_list = batch_tup
        # if label_tensor.max() < 0.5:
        #     return None

        input_devtensor = input_tensor.to(self.device)
        label_devtensor = label_tensor.to(self.device)

        prediction_devtensor = self.model(input_devtensor)
        # assert prediction_devtensor.is_contiguous()

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)

        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
        malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
        benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])

        with torch.no_grad():
            bPred_tensor = prediction_devtensor.to('cpu', non_blocking=True)
            diceLoss_tensor = diceLoss_devtensor.to('cpu', non_blocking=True)
            malLoss_tensor = malLoss_devtensor.to('cpu', non_blocking=True)
            benLoss_tensor = benLoss_devtensor.to('cpu', non_blocking=True)
            # flgLoss_devtensor = self.diceLoss(label_devtensor[:,0], label_devtensor[:,0] * prediction_devtensor[:,1])
            # flgLoss_tensor = flgLoss_devtensor.to('cpu', non_blocking=True)#.unsqueeze(1)

            # Label code per sample: 0 = neither class present, 1 = malignant
            # only, 2 = benign only, 3 = both; decoded in logPerformanceMetrics.
            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0]) + max2(label_tensor[:,1]) * 2
            # metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:, 0] * bPred_tensor[:, 1].to(torch.float32)) > 0.5)
            # metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum(label_tensor[:,0], bPred_tensor[:,1])

            # Per-sample pixel counts for each confusion cell, per class channel.
            bPred_tensor = bPred_tensor > 0.5
            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0],  bPred_tensor[:,0])
            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0], ~bPred_tensor[:,0])
            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0],  bPred_tensor[:,0])
            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1],  bPred_tensor[:,1])
            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1], ~bPred_tensor[:,1])
            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1],  bPred_tensor[:,1])

            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor
            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor
            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor
            # metrics_tensor[METRICS_FLG_LOSS_NDX, start_ndx:end_ndx] = flgLoss_tensor

            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor

            # TODO: replace with torch.autograd.detect_anomaly
            # assert np.isfinite(metrics_tensor).all()

        # Summing the per-class means weights the malignant and benign channels
        # equally, regardless of how many pixels each class covers.
        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
        return malLoss_devtensor.mean() + benLoss_devtensor.mean()
        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01):
        # sum2 = lambda t: t.sum([1,2,3,4])
        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]

        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
        dicePrediction_devtensor = sum2(prediction_devtensor)
        diceLabel_devtensor = sum2(label_devtensor)

        # Per-sample soft Dice loss: 1 - (2*intersection + eps) / (|pred| + |label| + eps).
        # The epsilon keeps samples with empty labels from producing 0/0.
        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)

        if not torch.isfinite(diceLoss_devtensor).all():
            log.debug('')
            log.debug('diceLoss_devtensor')
            log.debug(diceLoss_devtensor.to('cpu'))
            log.debug('diceCorrect_devtensor')
            log.debug(diceCorrect_devtensor.to('cpu'))
            log.debug('dicePrediction_devtensor')
            log.debug(dicePrediction_devtensor.to('cpu'))
            log.debug('diceLabel_devtensor')
            log.debug(diceLabel_devtensor.to('cpu'))
        return diceLoss_devtensor
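
    # Worked example of the soft Dice loss above (a sketch, not executed here):
    #   label      = [1, 1, 0, 0]  ->  diceLabel      = 2
    #   prediction = [1, 0, 0, 0]  ->  dicePrediction = 1, diceCorrect = 1
    #   loss = 1 - (2*1 + 0.01) / (1 + 2 + 0.01) ~= 0.332
    # Perfect overlap gives a loss of ~0; no overlap gives a loss of ~1.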
    def logImages(self, epoch_ndx, train_dl, test_dl):
        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
                ct = getCt(series_uid)
                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)

                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
                input_tensor = sample_tup[0].unsqueeze(0)
                label_tensor = sample_tup[1].unsqueeze(0)

                input_devtensor = input_tensor.to(self.device)
                label_devtensor = label_tensor.to(self.device)

                prediction_devtensor = self.model(input_devtensor)
                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()

                # Dimmed CT slice as the gray base; malignant prediction added
                # to the red channel, benign prediction to the green channel.
                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
                # log.debug([image_ary.__array_interface__['typestr']])
                # image_ary = (image_ary * 255).astype(np.uint8)
                # log.debug([image_ary.__array_interface__['typestr']])

                writer = getattr(self, mode_str + '_writer')
                try:
                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
                except Exception:
                    log.debug([image_ary.shape, image_ary.dtype])
                    raise

                if epoch_ndx == 1:
                    label_ary = label_tensor.numpy()

                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
                    image_ary[:,:,0] += label_ary[0,0] * 0.5
                    image_ary[:,:,1] += label_ary[0,1] * 0.25
                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(bool) | label_ary[0,1].astype(bool))) * 0.25
                    # log.debug([image_ary.__array_interface__['typestr']])
                    image_ary = (image_ary * 255).astype(np.uint8)
                    # log.debug([image_ary.__array_interface__['typestr']])

                    writer = getattr(self, mode_str + '_writer')
                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
    def logPerformanceMetrics(self,
            epoch_ndx,
            mode_str,
            metrics_tensor,
            # trainingMetrics_tensor,
            # testingMetrics_tensor,
            classificationThreshold_float=0.5,
        ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
        metrics_ary = metrics_tensor.cpu().detach().numpy()
        sum_ary = metrics_ary.sum(axis=1)
        assert np.isfinite(metrics_ary).all()

        # Label codes: classification stores the raw 0/1 label; segmentation
        # stores the 0-3 code built in computeSegmentationLoss.
        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        if self.cli_args.segmentation:
            benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        else:
            benLabel_mask = ~malLabel_mask
        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
        # malLabel_mask = ~benLabel_mask
        # malPred_mask = ~benPred_mask

        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
        # falsePos_count = benLabel_count - benCorrect_count
        # falseNeg_count = malLabel_count - malCorrect_count

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
        # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
        # nan_to_num guards against an epoch that contains no malignant samples.
        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()

        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100

        # The `or 1` fallbacks avoid division by zero when a class is absent.
        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

        log.info(("E{} {:8} "
            + "{loss/all:.4f} loss, "
            # + "{loss/flg:.4f} flagged loss, "
            # + "{flagged/all:-5.1f}% pixels flagged, "
            # + "{flagged/slices:-5.1f}% slices flagged, "
            + "{pr/precision:.4f} precision, "
            + "{pr/recall:.4f} recall, "
            + "{pr/f1_score:.4f} f1 score"
        ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/mal:.4f} loss, "
            + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_mal',
            malCorrect_count=malCorrect_count,
            malLabel_count=malLabel_count,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/ben:.4f} loss, "
            + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_ben',
            benCorrect_count=benCorrect_count,
            benLabel_count=benLabel_count,
            **metrics_dict,
        ))

        writer = getattr(self, mode_str + '_writer')
        for key, value in metrics_dict.items():
            writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)

        if not self.cli_args.segmentation:
            writer.add_pr_curve(
                'pr',
                metrics_ary[METRICS_LABEL_NDX],
                metrics_ary[METRICS_PRED_NDX],
                self.totalTrainingSamples_count,
            )

            # Exclude predictions already pinned to the correct extreme so the
            # histograms highlight the samples the model is unsure about.
            benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
            malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
            bins = [x/50.0 for x in range(51)]
            writer.add_histogram(
                'is_ben',
                metrics_ary[METRICS_PRED_NDX, benHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )
            writer.add_histogram(
                'is_mal',
                metrics_ary[METRICS_PRED_NDX, malHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

        # Composite model-selection score: reward F1, lightly penalize the
        # malignant loss and (more lightly still) the overall loss.
        score = 1 \
            + metrics_dict['pr/f1_score'] \
            - metrics_dict['loss/mal'] * 0.01 \
            - metrics_dict['loss/all'] * 0.0001

        return score
    def logModelMetrics(self, model):
        writer = self.trn_writer
        model = getattr(model, 'module', model)  # unwrap nn.DataParallel, if present

        for name, param in model.named_parameters():
            if param.requires_grad:
                min_data = float(param.data.min())
                max_data = float(param.data.max())
                max_extent = max(abs(min_data), abs(max_data))

                # Symmetric bins around zero, scaled to the parameter's extent.
                bins = [x/50*max_extent for x in range(-50, 51)]

                writer.add_histogram(
                    name.rsplit('.', 1)[-1] + '/' + name,
                    param.data.cpu().numpy(),
                    self.totalTrainingSamples_count,
                    bins=bins,
                )
    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, self.totalTrainingSamples_count),
        )
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        if hasattr(model, 'module'):
            model = model.module  # save the raw model, not the DataParallel wrapper

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
            # 'resumed_from': self.cli_args.resume,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            file_path = os.path.join(
                'data-unversioned',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, 'best'),
            )
            torch.save(state, file_path)
            log.debug("Saved model params to {}".format(file_path))
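
    # Restoring a checkpoint saved above might look like this (a sketch;
    # `file_path` and the constructor arguments are assumptions, not part of
    # this file):
    #
    #   state = torch.load(file_path)
    #   model = UNetWrapper(...) if state['model_name'] == 'UNetWrapper' else LunaModel()
    #   model.load_state_dict(state['model_state'])
    #   optimizer = Adam(model.parameters())
    #   optimizer.load_state_dict(state['optimizer_state'])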

if __name__ == '__main__':
    sys.exit(LunaTrainingApp().main() or 0)
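
# Example invocation (assuming this file lives in a package such as p2ch10,
# matching the default --tb-prefix; adjust the module path to your layout):
#   python -m p2ch10.training --segmentation --epochs 10 --batch-size 16 my-comment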