# train_seg.py
import argparse
import datetime
import os
import socket
import sys

import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate, xyz2irc
from util.logconf import logging
from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from .model_seg import UNetWrapper

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
METRICS_LABEL_NDX = 0
METRICS_LOSS_NDX = 1
METRICS_MAL_LOSS_NDX = 2
# METRICS_ALL_LOSS_NDX = 3
METRICS_MTP_NDX = 4
METRICS_MFN_NDX = 5
METRICS_MFP_NDX = 6
METRICS_ATP_NDX = 7
METRICS_AFN_NDX = 8
METRICS_AFP_NDX = 9
METRICS_SIZE = 10
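
# The metrics tensors allocated in doTraining/doTesting have shape
# (METRICS_SIZE, num_samples): one row per statistic above, one column per
# sample, so logMetrics can recover per-sample values later. The M* rows
# count malignant-pixel stats; the A* rows count stats over all nodule pixels.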

class LunaTrainingApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=16,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--augmented',
            help="Augment the training data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-flip',
            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
            action='store_true',
            default=False,
        )
        # parser.add_argument('--augment-offset',
        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
        #     action='store_true',
        #     default=False,
        # )
        # parser.add_argument('--augment-scale',
        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
        #     action='store_true',
        #     default=False,
        # )
        parser.add_argument('--augment-rotate',
            help="Augment the training data by randomly rotating the data around the head-foot axis.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-noise',
            help="Augment the training data by randomly adding noise to the data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch12',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )
        self.cli_args = parser.parse_args(sys_argv)

        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.totalTrainingSamples_count = 0
        self.trn_writer = None
        self.tst_writer = None

        # --augmented enables every augmentation at once; the individual
        # --augment-* flags turn each transform on separately.
        augmentation_dict = {}
        if self.cli_args.augmented or self.cli_args.augment_flip:
            augmentation_dict['flip'] = True
        if self.cli_args.augmented or self.cli_args.augment_rotate:
            augmentation_dict['rotate'] = True
        if self.cli_args.augmented or self.cli_args.augment_noise:
            augmentation_dict['noise'] = 25.0
        self.augmentation_dict = augmentation_dict

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()

    def initModel(self):
        model = UNetWrapper(
            in_channels=8,
            n_classes=1,
            depth=4,
            wf=3,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )
        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model
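
    # Note: in_channels=8 is one more than the 2 * contextSlices_count + 1 = 7
    # CT slices implied by the datasets below, which suggests the dataset stacks
    # one extra channel onto the input (logImages treats ct_tensor[-1]
    # specially). This is an inference from this file alone, not verified
    # against dsets.py.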

    def initOptimizer(self):
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
        # return Adam(self.model.parameters())
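
    # With momentum=0.99, SGD effectively averages gradients over roughly the
    # last 1/(1 - 0.99) = 100 steps, so the modest lr=0.001 still produces
    # effective steps on the order of 0.1 per unit gradient.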

    def initTrainDl(self):
        train_ds = TrainingLuna2dSegmentationDataset(
            test_stride=10,
            isTestSet_bool=False,
            contextSlices_count=3,
            augmentation_dict=self.augmentation_dict,
        )
        train_dl = DataLoader(
            train_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return train_dl
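
    # The batch size is multiplied by the GPU count because nn.DataParallel
    # splits each batch across devices; the scaling keeps the per-GPU batch
    # size equal to --batch-size.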

    def initTestDl(self):
        test_ds = Luna2dSegmentationDataset(
            test_stride=10,
            isTestSet_bool=True,
            contextSlices_count=3,
        )
        test_dl = DataLoader(
            test_ds,
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        return test_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_seg_' + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        test_dl = self.initTestDl()
        # self.logModelMetrics(self.model)

        best_score = 0.0
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(test_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
            self.logImages(epoch_ndx, 'trn', train_dl)
            self.logImages(epoch_ndx, 'tst', test_dl)
            # self.logModelMetrics(self.model)

            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
            score = self.logMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
            best_score = max(score, best_score)

            # Flag this epoch's checkpoint as best when its score matches the
            # running maximum just computed.
            self.saveModel('seg', epoch_ndx, score == best_score)

        # __init__ always sets trn_writer, so hasattr() would be True even when
        # the writers were never created; check for None instead.
        if self.trn_writer is not None:
            self.trn_writer.close()
            self.tst_writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
        self.model.train()

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_devtensor)
            loss_var.backward()
            self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
        return trainingMetrics_devtensor.to('cpu')
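
    # Per-sample metrics are accumulated on the GPU for the whole epoch and
    # moved to the CPU in a single transfer before logMetrics converts them
    # to numpy.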

    def doTesting(self, epoch_ndx, test_dl):
        with torch.no_grad():
            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
            self.model.eval()

            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)

        return testingMetrics_devtensor.to('cpu')

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
        input_tensor, label_tensor, label_list, ben_tensor, mal_tensor, _series_list, _start_list = batch_tup

        input_devtensor = input_tensor.to(self.device, non_blocking=True)
        label_devtensor = label_tensor.to(self.device, non_blocking=True)
        mal_devtensor = mal_tensor.to(self.device, non_blocking=True)
        ben_devtensor = ben_tensor.to(self.device, non_blocking=True)

        # This batch occupies columns [start_ndx, end_ndx) of the epoch-wide
        # metrics tensor.
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)

        # Per-sample count of overlapping "on" pixels between two masks.
        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)

        prediction_devtensor = self.model(input_devtensor)
        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)

        # Everything below is bookkeeping, not part of the gradient.
        with torch.no_grad():
            predictionBool_devtensor = (prediction_devtensor > 0.5).to(torch.float32)

            metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
            metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_devtensor

            # True/false positives and false negatives over all nodule pixels.
            # benPred_devtensor = predictionBool_devtensor * (1 - mal_devtensor)
            tp = intersectionSum(    label_devtensor,     predictionBool_devtensor)
            fn = intersectionSum(    label_devtensor, 1 - predictionBool_devtensor)
            fp = intersectionSum(1 - label_devtensor,     predictionBool_devtensor)
            # ls = self.diceLoss(label_devtensor, benPred_devtensor)

            metrics_devtensor[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
            metrics_devtensor[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
            metrics_devtensor[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
            # metrics_devtensor[METRICS_ALL_LOSS_NDX, start_ndx:end_ndx] = ls
            del tp, fn, fp

            # Repeat the tally for malignant pixels only; predictions over
            # benign pixels are masked out first.
            malPred_devtensor = predictionBool_devtensor * (1 - ben_devtensor)
            tp = intersectionSum(    mal_devtensor,     malPred_devtensor)
            fn = intersectionSum(    mal_devtensor, 1 - malPred_devtensor)
            fp = intersectionSum(1 - label_devtensor,   malPred_devtensor)
            ls = self.diceLoss(mal_devtensor, malPred_devtensor)

            metrics_devtensor[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
            metrics_devtensor[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
            # metrics_devtensor[METRICS_MFP_NDX, start_ndx:end_ndx] = fp
            metrics_devtensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = ls
            del malPred_devtensor, tp, fn, fp, ls

        return diceLoss_devtensor.mean()
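
    # The Dice loss below is 1 - (2*|A∩B| + eps) / (|A| + |B| + eps), computed
    # per sample over the flattened masks. It is called with the raw,
    # non-thresholded predictions so it stays differentiable.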

    # def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=1024, p=False):
        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
        diceLabel_devtensor = sum_dim1(label_devtensor)
        dicePrediction_devtensor = sum_dim1(prediction_devtensor)
        diceCorrect_devtensor = sum_dim1(prediction_devtensor * label_devtensor)

        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)

        if p:
            log.debug([])
            log.debug(['diceCorrect_devtensor   ', diceCorrect_devtensor[0].item()])
            log.debug(['dicePrediction_devtensor', dicePrediction_devtensor[0].item()])
            log.debug(['diceLabel_devtensor     ', diceLabel_devtensor[0].item()])
            log.debug(['2*diceCorrect_devtensor ', 2 * diceCorrect_devtensor[0].item()])
            log.debug(['Prediction + Label      ', dicePrediction_devtensor[0].item() + diceLabel_devtensor[0].item()])
            log.debug(['diceLoss_devtensor      ', diceLoss_devtensor[0].item()])

        return diceLoss_devtensor
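
    # A toy calculation (not from a real run) to show what epsilon=1024 does:
    # a perfectly matched 16-pixel mask scores 1 - (32 + 1024)/(32 + 1024) = 0,
    # while missing it entirely scores 1 - 1024/(16 + 1024) ≈ 0.015, so the
    # large epsilon mutes the loss on small structures and on slices with no
    # positive pixels at all.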

    def logImages(self, epoch_ndx, mode_str, dl):
        for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
            ct = getCt(series_uid)

            # Sample five slices spread evenly through the CT.
            for slice_ndx in range(0, ct.ary.shape[0], ct.ary.shape[0] // 5):
                sample_tup = dl.dataset[(series_uid, slice_ndx, False)]
                ct_tensor, nodule_tensor, label_int, ben_tensor, mal_tensor, series_uid, ct_ndx = sample_tup

                # Shift the CT channels (all but the last) from roughly
                # [-1000, 1000] HU into [0, 1] for display.
                ct_tensor[:-1,:,:] += 1000
                ct_tensor[:-1,:,:] /= 2000
                input_devtensor = ct_tensor.to(self.device)
                label_devtensor = nodule_tensor.to(self.device)

                prediction_devtensor = self.model(input_devtensor.unsqueeze(0))[0]
                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
                label_ary = nodule_tensor.numpy()
                ben_ary = ben_tensor.numpy()
                mal_ary = mal_tensor.numpy()

                # Grayscale CT base image with the prediction overlaid:
                # red = predicted pixels outside the label (false positives),
                # green = prediction over malignant pixels, blue = over benign.
                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                image_ary[:,:,:] = ct_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1))
                image_ary[:,:,0] += prediction_ary[0] * (1 - label_ary[0]) # Red
                image_ary[:,:,1] += prediction_ary[0] * mal_ary            # Green
                image_ary[:,:,2] += prediction_ary[0] * ben_ary            # Blue

                writer = getattr(self, mode_str + '_writer')
                image_ary *= 0.5
                image_ary[image_ary < 0] = 0
                image_ary[image_ary > 1] = 1
                writer.add_image('{}/{}_prediction_{}'.format(mode_str, i, slice_ndx), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
                # self.diceLoss(label_devtensor, prediction_devtensor, p=True)

                # The ground-truth overlay only needs to be logged once.
                if epoch_ndx == 1:
                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
                    image_ary[:,:,:] = ct_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1))
                    image_ary[:,:,0] += (1 - label_ary[0]) * ct_tensor[-1].numpy() # Red
                    image_ary[:,:,1] += mal_ary # Green
                    image_ary[:,:,2] += ben_ary # Blue

                    writer = getattr(self, mode_str + '_writer')
                    image_ary *= 0.5
                    image_ary[image_ary < 0] = 0
                    image_ary[image_ary > 1] = 1
                    writer.add_image('{}/{}_label_{}'.format(mode_str, i, slice_ndx), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
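
    # Both images and scalars are logged with totalTrainingSamples_count as the
    # global step, so their x-axes line up in TensorBoard across batch sizes.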

    def logMetrics(self,
            epoch_ndx,
            mode_str,
            metrics_tensor,
    ):
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        metrics_ary = metrics_tensor.cpu().detach().numpy()
        sum_ary = metrics_ary.sum(axis=1)
        assert np.isfinite(metrics_ary).all()

        # Select the samples flagged as containing malignancy (label values 1 and 3).
        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
        # allLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)

        allLabel_count = sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]
        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
        allCorrect_count = sum_ary[METRICS_ATP_NDX]
        malCorrect_count = sum_ary[METRICS_MTP_NDX]
        #
        # falsePos_count = allLabel_count - allCorrect_count
        # falseNeg_count = malLabel_count - malCorrect_count

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
        # nan_to_num covers the case where no malignant samples are present.
        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
        # metrics_dict['loss/all'] = metrics_ary[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()

        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
        metrics_dict['correct/all'] = sum_ary[METRICS_ATP_NDX] / (sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]) * 100

        # The `or 1` fallback guards against division by zero.
        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_ATP_NDX] / ((sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_ary[METRICS_ATP_NDX] / ((sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]) or 1)
        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

        log.info(("E{} {:8} "
            + "{loss/all:.4f} loss, "
            + "{pr/precision:.4f} precision, "
            + "{pr/recall:.4f} recall, "
            + "{pr/f1_score:.4f} f1 score"
        ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/all:.4f} loss, "
            + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_all',
            allCorrect_count=allCorrect_count,
            allLabel_count=allLabel_count,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
            + "{loss/mal:.4f} loss, "
            + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
        ).format(
            epoch_ndx,
            mode_str + '_mal',
            malCorrect_count=malCorrect_count,
            malLabel_count=malLabel_count,
            **metrics_dict,
        ))

        writer = getattr(self, mode_str + '_writer')
        prefix_str = 'seg_'
        for key, value in metrics_dict.items():
            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)

        score = 1 \
            + metrics_dict['pr/f1_score'] \
            - metrics_dict['pr/recall'] * 0.01 \
            - metrics_dict['loss/mal'] * 0.001 \
            - metrics_dict['loss/all'] * 0.0001

        return score
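
    # The f1 term dominates the score; the recall and loss terms are scaled
    # down by factors of 100 to 10,000, so they mostly separate checkpoints
    # whose f1 scores are nearly identical.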

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             writer.add_histogram(
    #                 name.rsplit('.', 1)[-1] + '/' + name,
    #                 param.data.cpu().numpy(),
    #                 # metrics_ary[METRICS_PRED_NDX, benHist_mask],
    #                 self.totalTrainingSamples_count,
    #                 # bins=bins,
    #             )
    #
    #             # print name, param.data

    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            ),
        )
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        # Unwrap nn.DataParallel so the saved state_dict keys don't carry the
        # 'module.' prefix.
        if hasattr(model, 'module'):
            model = model.module

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)
        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            file_path = os.path.join(
                'data-unversioned',
                'part2',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(
                    type_str,
                    self.time_str,
                    self.cli_args.comment,
                    'best',
                ),
            )
            torch.save(state, file_path)
            log.debug("Saved model params to {}".format(file_path))


if __name__ == '__main__':
    sys.exit(LunaTrainingApp().main() or 0)
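
# Example invocation (hypothetical module path; the relative imports above mean
# this script must be run as part of its package):
#   python -m p2ch12.train_seg --epochs 10 --augmented --batch-size 16 seg-run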