# training.py

import argparse
import datetime
import hashlib
import os
import shutil
import socket
import sys

import numpy as np

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from util.logconf import logging
from .model import UNetWrapper, SegmentationAugmentation

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeClassificationLoss and logMetrics to index into metrics_t/metrics_a
# METRICS_LABEL_NDX = 0
METRICS_LOSS_NDX = 1
# METRICS_FN_LOSS_NDX = 2
# METRICS_ALL_LOSS_NDX = 3

# METRICS_PTP_NDX = 4
# METRICS_PFN_NDX = 5
# METRICS_MFP_NDX = 6
METRICS_TP_NDX = 7
METRICS_FN_NDX = 8
METRICS_FP_NDX = 9

METRICS_SIZE = 10
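# The metrics tensors have shape (METRICS_SIZE, num_samples): METRICS_LOSS_NDX
# holds each sample's dice loss, and METRICS_TP/FN/FP_NDX hold each sample's
# per-pixel true-positive, false-negative, and false-positive counts.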


class SegmentationTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=16,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--augmented',
            help="Augment the training data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-flip',
            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-offset',
            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-scale',
            help="Augment the training data by randomly increasing or decreasing the size of the candidate.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-rotate',
            help="Augment the training data by randomly rotating the data around the head-foot axis.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-noise',
            help="Augment the training data by randomly adding noise to the data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch13',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )

        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.totalTrainingSamples_count = 0
        self.trn_writer = None
        self.val_writer = None

        self.augmentation_dict = {}
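        # --augmented turns on every augmentation at once; the individual
        # --augment-* flags enable them one at a time. The values below are
        # the strengths handed through to SegmentationAugmentation.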
        if self.cli_args.augmented or self.cli_args.augment_flip:
            self.augmentation_dict['flip'] = True
        if self.cli_args.augmented or self.cli_args.augment_offset:
            self.augmentation_dict['offset'] = 0.03
        if self.cli_args.augmented or self.cli_args.augment_scale:
            self.augmentation_dict['scale'] = 0.2
        if self.cli_args.augmented or self.cli_args.augment_rotate:
            self.augmentation_dict['rotate'] = True
        if self.cli_args.augmented or self.cli_args.augment_noise:
            self.augmentation_dict['noise'] = 25.0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.segmentation_model, self.augmentation_model = self.initModel()
        self.optimizer = self.initOptimizer()

    def initModel(self):
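        # UNet configuration: in_channels=7 matches the 3 context slices on
        # each side of the center slice plus the center slice itself; depth=3
        # gives three down/upsampling levels; wf=4 means 2**4 = 16 filters in
        # the first layer; padding keeps the output the same size as the input.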
        segmentation_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )
        augmentation_model = SegmentationAugmentation(**self.augmentation_dict)

        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                segmentation_model = nn.DataParallel(segmentation_model)
                augmentation_model = nn.DataParallel(augmentation_model)
            segmentation_model = segmentation_model.to(self.device)
            augmentation_model = augmentation_model.to(self.device)

        return segmentation_model, augmentation_model

    def initOptimizer(self):
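        # Adam with its default learning rate (1e-3) is a reasonable starting
        # point; an SGD alternative is left commented out below.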
        return Adam(self.segmentation_model.parameters())
        # return SGD(self.segmentation_model.parameters(), lr=0.001, momentum=0.99)

    def initTrainDl(self):
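        # Every 10th series is held out for validation (val_stride=10);
        # contextSlices_count=3 provides the three slices of context on each
        # side of the center slice, matching the model's 7 input channels.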
        train_ds = TrainingLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
        )

        batch_size = self.cli_args.batch_size
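        # With DataParallel, each GPU processes --batch-size samples, so the
        # effective batch grows with the number of visible devices.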
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl

    def initValDl(self):
        val_ds = Luna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=True,
            contextSlices_count=3,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl

    def initTensorboardWriters(self):
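        # The writers are created lazily, on the first call from logMetrics,
        # so a run that fails before producing any metrics doesn't leave empty
        # TensorBoard directories behind.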
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '_val_seg_' + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        best_score = 0.0
        self.validation_cadence = 5
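        # Validation, checkpointing, and image logging run only on the first
        # epoch and then every validation_cadence epochs; full-slice validation
        # is slow enough that running it every epoch isn't worth the time.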
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
                # if validation is wanted
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                self.saveModel('seg', epoch_ndx, score == best_score)

                self.logImages(epoch_ndx, 'trn', train_dl)
                self.logImages(epoch_ndx, 'val', val_dl)

        self.trn_writer.close()
        self.val_writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset), device=self.device)
        self.segmentation_model.train()
        train_dl.dataset.shuffleSamples()
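        # enumerateWithEstimate wraps the loader with a logged progress/ETA
        # estimate; start_ndx=num_workers presumably lets the estimate skip the
        # first batches while the loader worker processes are still warming up.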
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g)
            loss_var.backward()

            self.optimizer.step()

        self.totalTrainingSamples_count += trnMetrics_g.size(1)

        return trnMetrics_g.to('cpu')

    def doValidation(self, epoch_ndx, val_dl):
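        # No gradients are needed here: torch.no_grad() skips building the
        # autograd graph, saving memory and time during the forward-only
        # validation passes.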
        with torch.no_grad():
            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset), device=self.device)
            self.segmentation_model.eval()

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         classificationThreshold=0.5):
        input_t, label_t, series_list, _slice_ndx_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        if self.segmentation_model.training and self.augmentation_dict:
            input_g, label_g = self.augmentation_model(input_g, label_g)

        prediction_g = self.segmentation_model(input_g)
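
        # diceLoss_g scores the whole prediction against the label, while
        # fnLoss_g scores only the pixels inside the label (prediction_g *
        # label_g zeroes everything else), so it penalizes false negatives only.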
        diceLoss_g = self.diceLoss(prediction_g, label_g)
        fnLoss_g = self.diceLoss(prediction_g * label_g, label_g)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + input_t.size(0)
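
        # Per-sample statistics are collected outside of autograd; thresholding
        # the soft prediction at 0.5 yields a hard mask for counting true
        # positive, false negative, and false positive pixels.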
        with torch.no_grad():
            predictionBool_g = (prediction_g[:, 0:1]
                                > classificationThreshold).to(torch.float32)

            tp = (     predictionBool_g *  label_g).sum(dim=[1,2,3])
            fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
            fp = (     predictionBool_g * (~label_g)).sum(dim=[1,2,3])

            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
            metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp
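
        # Weighting the false-negative loss 8x strongly favors recall over
        # precision; the downstream classification step is expected to clean
        # up the extra false positives this produces.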
        return diceLoss_g.mean() + fnLoss_g.mean() * 8

    def diceLoss(self, prediction_g, label_g, epsilon=1):
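        # Soft Dice loss per sample: ratio = (2*|P*L| + eps) / (|P| + |L| + eps)
        # and loss = 1 - ratio. epsilon keeps the ratio well-behaved when both
        # the prediction and the label are (nearly) empty.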
        diceLabel_g = label_g.sum(dim=[1,2,3])
        dicePrediction_g = prediction_g.sum(dim=[1,2,3])
        diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])

        diceRatio_g = (2 * diceCorrect_g + epsilon) \
            / (dicePrediction_g + diceLabel_g + epsilon)

        return 1 - diceRatio_g

    def logImages(self, epoch_ndx, mode_str, dl):
        self.segmentation_model.eval()
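
        # Render predictions for the first 12 series (sorted, so the same
        # series show up from run to run) at 6 evenly spaced slices per CT.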
        images = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images):
            ct = getCt(series_uid)

            for slice_ndx in range(6):
                ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5
                sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)

                ct_t, label_t, series_uid, ct_ndx = sample_tup

                input_g = ct_t.to(self.device).unsqueeze(0)
                label_g = pos_g = label_t.to(self.device).unsqueeze(0)

                prediction_g = self.segmentation_model(input_g)[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                label_a = label_g.cpu().numpy()[0][0] > 0.5

                ct_t[:-1,:,:] /= 2000
                ct_t[:-1,:,:] += 0.5

                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()

                image_a = np.zeros((512, 512, 3), dtype=np.float32)
                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
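
                # Overlay the prediction on the grayscale CT slice: false
                # positives come out red, false negatives orange (red plus
                # half green), and true positives green.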
                image_a[:,:,0] += prediction_a & (1 - label_a)
                image_a[:,:,0] += (1 - prediction_a) & label_a
                image_a[:,:,1] += ((1 - prediction_a) & label_a) * 0.5

                image_a[:,:,1] += prediction_a & label_a
                image_a *= 0.5
                image_a.clip(0, 1, image_a)

                writer = getattr(self, mode_str + '_writer')
                writer.add_image(
                    f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )

                if epoch_ndx == 1:
                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                    # image_a[:,:,0] += (1 - label_a) & lung_a # Red
                    image_a[:,:,1] += label_a  # Green
                    # image_a[:,:,2] += neg_a  # Blue

                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )
                # This flush prevents TB from getting confused about which
                # data item belongs where.
                writer.flush()

    def logMetrics(self, epoch_ndx, mode_str, metrics_t):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        metrics_a = metrics_t.detach().numpy()
        sum_a = metrics_a.sum(axis=1)
        assert np.isfinite(metrics_a).all()

        allLabel_count = sum_a[METRICS_TP_NDX] + sum_a[METRICS_FN_NDX]
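
        # The tp/fn/fp percentages are all relative to the total number of
        # labeled (positive) pixels, which is why percent_all/fp can
        # legitimately exceed 100%.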
        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()

        metrics_dict['percent_all/tp'] = \
            sum_a[METRICS_TP_NDX] / (allLabel_count or 1) * 100
        metrics_dict['percent_all/fn'] = \
            sum_a[METRICS_FN_NDX] / (allLabel_count or 1) * 100
        metrics_dict['percent_all/fp'] = \
            sum_a[METRICS_FP_NDX] / (allLabel_count or 1) * 100

        precision = metrics_dict['pr/precision'] = sum_a[METRICS_TP_NDX] \
            / ((sum_a[METRICS_TP_NDX] + sum_a[METRICS_FP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_a[METRICS_TP_NDX] \
            / ((sum_a[METRICS_TP_NDX] + sum_a[METRICS_FN_NDX]) or 1)

        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
            / ((precision + recall) or 1)

        log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
                  + "{pr/precision:.4f} precision, "
                  + "{pr/recall:.4f} recall, "
                  + "{pr/f1_score:.4f} f1 score"
                  ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
                  + "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
                  ).format(
            epoch_ndx,
            mode_str + '_all',
            **metrics_dict,
        ))

        self.initTensorboardWriters()
        writer = getattr(self, mode_str + '_writer')

        prefix_str = 'seg_'

        for key, value in metrics_dict.items():
            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)

        writer.flush()
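
        # Recall is the score used to pick the "best" model: the segmentation
        # stage is tuned to miss as few nodules as possible, and the
        # classification stage downstream is responsible for rejecting the
        # resulting false positives.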
        score = metrics_dict['pr/recall']

        return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             writer.add_histogram(
    #                 name.rsplit('.', 1)[-1] + '/' + name,
    #                 param.data.cpu().numpy(),
    #                 # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                 self.totalTrainingSamples_count,
    #                 # bins=bins,
    #             )
    #
    #             # print name, param.data

    def saveModel(self, type_str, epoch_ndx, isBest=False):
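        # Only the state_dicts of the model and optimizer are persisted (plus
        # some bookkeeping), not the model object itself; unwrapping
        # DataParallel first keeps the checkpoint loadable on a single device.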
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )

        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.segmentation_model
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        state = {
            'sys_argv': sys.argv,
            'time': str(datetime.datetime.now()),
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)

        log.info("Saved model params to {}".format(file_path))

        if isBest:
            best_path = os.path.join(
                'data-unversioned', 'part2', 'models',
                self.cli_args.tb_prefix,
                f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
            shutil.copyfile(file_path, best_path)

            log.info("Saved model params to {}".format(best_path))

        with open(file_path, 'rb') as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())
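

# Hypothetical command-line invocation, assuming this module lives under the
# p2ch13 package (matching the default --tb-prefix) and the LUNA dataset is
# available on disk:
#
#     python -m p2ch13.training --epochs 10 --augmented my-run-comment
#
# Omit --augmented (or use the individual --augment-* flags) to control which
# augmentations are applied.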
if __name__ == '__main__':
    SegmentationTrainingApp().main()