# training.py

import argparse
import datetime
import hashlib
import os
import shutil
import sys

import numpy as np

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from util.util import enumerateWithEstimate
from util.logconf import logging
from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from .model import UNetWrapper, SegmentationAugmentation

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used by computeBatchLoss and logMetrics to index into metrics_t/metrics_a
# METRICS_LABEL_NDX = 0
METRICS_LOSS_NDX = 1
# METRICS_FN_LOSS_NDX = 2
# METRICS_ALL_LOSS_NDX = 3
# METRICS_PTP_NDX = 4
# METRICS_PFN_NDX = 5
# METRICS_MFP_NDX = 6
METRICS_TP_NDX = 7
METRICS_FN_NDX = 8
METRICS_FP_NDX = 9
METRICS_SIZE = 10

class SegmentationTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=16,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--augmented',
            help="Augment the training data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-flip',
            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-offset',
            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-scale',
            help="Augment the training data by randomly increasing or decreasing the size of the candidate.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-rotate',
            help="Augment the training data by randomly rotating the data around the head-foot axis.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augment-noise',
            help="Augment the training data by randomly adding noise to the data.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch13',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.totalTrainingSamples_count = 0
        self.trn_writer = None
        self.val_writer = None

        self.augmentation_dict = {}
        if self.cli_args.augmented or self.cli_args.augment_flip:
            self.augmentation_dict['flip'] = True
        if self.cli_args.augmented or self.cli_args.augment_offset:
            self.augmentation_dict['offset'] = 0.03
        if self.cli_args.augmented or self.cli_args.augment_scale:
            self.augmentation_dict['scale'] = 0.2
        if self.cli_args.augmented or self.cli_args.augment_rotate:
            self.augmentation_dict['rotate'] = True
        if self.cli_args.augmented or self.cli_args.augment_noise:
            self.augmentation_dict['noise'] = 25.0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.segmentation_model, self.augmentation_model = self.initModel()
        self.optimizer = self.initOptimizer()
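
    # The U-Net below is configured for full 512x512 CT slices: in_channels=7
    # matches the 3+1+3 slices delivered by the datasets (contextSlices_count=3),
    # and n_classes=1 yields a single-channel nodule mask. depth and wf keep the
    # model small; wf=4 is assumed here to mean 2**4 filters in the first layer,
    # per the underlying U-Net implementation that UNetWrapper wraps.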
    def initModel(self):
        segmentation_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )
        augmentation_model = SegmentationAugmentation(**self.augmentation_dict)

        if self.use_cuda:
            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                segmentation_model = nn.DataParallel(segmentation_model)
                augmentation_model = nn.DataParallel(augmentation_model)
            segmentation_model = segmentation_model.to(self.device)
            augmentation_model = augmentation_model.to(self.device)

        return segmentation_model, augmentation_model
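
    # Adam at its default learning rate (1e-3) is a reasonable starting point;
    # the commented-out SGD line is kept as the alternative for experimentation.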
    def initOptimizer(self):
        return Adam(self.segmentation_model.parameters())
        # return SGD(self.segmentation_model.parameters(), lr=0.001, momentum=0.99)
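
    # When multiple GPUs are present, nn.DataParallel splits each batch across
    # them, so the per-step batch size is scaled by the device count to keep
    # every GPU fed.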
    def initTrainDl(self):
        train_ds = TrainingLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl

    def initValDl(self):
        val_ds = Luna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=True,
            contextSlices_count=3,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl
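
    # The writers are created lazily, on the first call to logMetrics, so that
    # no TensorBoard run directories appear until there are metrics to write.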
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '_val_seg_' + self.cli_args.comment)
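
    # Training runs every epoch; full validation, checkpointing, and image
    # logging happen only on epoch 1 and every validation_cadence epochs,
    # since they are comparatively expensive.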
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        best_score = 0.0
        self.validation_cadence = 5
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                self.saveModel('seg', epoch_ndx, score == best_score)

                self.logImages(epoch_ndx, 'trn', train_dl)
                self.logImages(epoch_ndx, 'val', val_dl)

        self.trn_writer.close()
        self.val_writer.close()
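
    # Per-sample metrics are accumulated in a METRICS_SIZE x num_samples tensor
    # that stays on the GPU while the epoch runs and is moved to the CPU
    # afterwards for aggregation in logMetrics.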
    def doTraining(self, epoch_ndx, train_dl):
        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset), device=self.device)
        self.segmentation_model.train()
        train_dl.dataset.shuffleSamples()

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g)
            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += trnMetrics_g.size(1)

        return trnMetrics_g.to('cpu')

    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset), device=self.device)
            self.segmentation_model.eval()

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')
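
    # The training loss combines two Dice losses: diceLoss_g over the whole
    # prediction, and fnLoss_g over prediction * label, which can only be hurt
    # by false negatives. Weighting the latter 8x deliberately trades precision
    # for recall, which is what we want from a screening-stage segmenter.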
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g, classificationThreshold=0.5):
        input_t, label_t, series_list, _slice_ndx_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        if self.segmentation_model.training and self.augmentation_dict:
            input_g, label_g = self.augmentation_model(input_g, label_g)

        prediction_g = self.segmentation_model(input_g)

        diceLoss_g = self.diceLoss(prediction_g, label_g)
        fnLoss_g = self.diceLoss(prediction_g * label_g, label_g)

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + input_t.size(0)

        with torch.no_grad():
            predictionBool_g = \
                (prediction_g[:, 0:1] > classificationThreshold).to(torch.float32)

            # metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
            # metrics_g[METRICS_FN_LOSS_NDX, start_ndx:end_ndx] = fnLoss_g

            intersectionSum = lambda a, b: (a * b).sum(dim=[1, 2, 3])
            tp = intersectionSum(    predictionBool_g,  label_g)
            fn = intersectionSum(1 - predictionBool_g,  label_g)
            fp = intersectionSum(    predictionBool_g, ~label_g)

            metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp

            del tp, fn, fp

        return diceLoss_g.mean() + fnLoss_g.mean() * 8
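
    # Soft Dice: the per-sample ratio of 2 * |prediction AND label| to
    # |prediction| + |label|, computed on raw probabilities rather than a
    # thresholded mask so it stays differentiable. epsilon guards against 0/0
    # on slices with no labeled pixels. For example, a slice with 10 labeled
    # pixels, all predicted at 1.0 and nothing else, gives a ratio of
    # (2*10 + 1) / (10 + 10 + 1) = 1.0, i.e. a loss of 0.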
    def diceLoss(self, prediction_g, label_g, epsilon=1):
        # log.debug([prediction_g.shape, label_g.shape])
        diceLabel_g = label_g.sum(dim=[1, 2, 3])
        dicePrediction_g = prediction_g.sum(dim=[1, 2, 3])
        diceCorrect_g = (prediction_g * label_g).sum(dim=[1, 2, 3])

        diceRatio_g = (2 * diceCorrect_g + epsilon) \
            / (dicePrediction_g + diceLabel_g + epsilon)

        return 1 - diceRatio_g
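
    # Renders up to 12 series, 6 slices each, into TensorBoard. In the
    # prediction images the grayscale CT is overlaid with red for false
    # positives, orange for false negatives (red plus half green), and green
    # for true positives; the epoch-1 label images show the ground truth in
    # green for comparison.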
    def logImages(self, epoch_ndx, mode_str, dl):
        self.segmentation_model.eval()

        images_iter = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images_iter):
            ct = getCt(series_uid)

            for slice_ndx in range(6):
                ct_ndx = slice_ndx * ct.hu_a.shape[0] // 5
                ct_ndx = min(ct_ndx, ct.hu_a.shape[0] - 1)
                sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)

                ct_t, label_t, series_uid, ct_ndx = sample_tup

                input_g = ct_t.to(self.device).unsqueeze(0)
                label_g = label_t.to(self.device).unsqueeze(0)

                prediction_g = self.segmentation_model(input_g)[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                label_a = label_g.cpu().numpy()[0][0] > 0.5

                ct_t[:-1,:,:] /= 1000
                ct_t[:-1,:,:] += 1
                ct_t[:-1,:,:] /= 2
                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()

                image_a = np.zeros((512, 512, 3), dtype=np.float32)
                image_a[:,:,:] = ctSlice_a.reshape((512, 512, 1))
                image_a[:,:,0] += prediction_a & (1 - label_a)
                image_a[:,:,0] += (1 - prediction_a) & label_a
                image_a[:,:,1] += ((1 - prediction_a) & label_a) * 0.5
                image_a[:,:,1] += prediction_a & label_a
                image_a *= 0.5
                image_a.clip(0, 1, image_a)

                writer = getattr(self, mode_str + '_writer')
                writer.add_image(
                    '{}/{}_prediction_{}'.format(
                        mode_str,
                        series_ndx,
                        slice_ndx,
                    ),
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )

                if epoch_ndx == 1:
                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
                    image_a[:,:,:] = ctSlice_a.reshape((512, 512, 1))
                    # image_a[:,:,0] += (1 - label_a) & lung_a # Red
                    image_a[:,:,1] += label_a  # Green
                    # image_a[:,:,2] += neg_a # Blue

                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )
                writer.flush()
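
    # The percent_all values are relative to the total number of labeled
    # pixels, so percent_all/fp can legitimately exceed 100% when the model
    # marks more false-positive pixels than there are positive pixels in the
    # labels.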
    def logMetrics(self,
        epoch_ndx,
        mode_str,
        metrics_t,
    ):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        metrics_a = metrics_t.detach().numpy()
        sum_a = metrics_a.sum(axis=1)
        assert np.isfinite(metrics_a).all()

        allLabel_count = sum_a[METRICS_TP_NDX] + sum_a[METRICS_FN_NDX]

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()

        metrics_dict['percent_all/tp'] = \
            sum_a[METRICS_TP_NDX] / (allLabel_count or 1) * 100
        metrics_dict['percent_all/fn'] = \
            sum_a[METRICS_FN_NDX] / (allLabel_count or 1) * 100
        metrics_dict['percent_all/fp'] = \
            sum_a[METRICS_FP_NDX] / (allLabel_count or 1) * 100

        precision = metrics_dict['pr/precision'] = sum_a[METRICS_TP_NDX] \
            / ((sum_a[METRICS_TP_NDX] + sum_a[METRICS_FP_NDX]) or 1)
        recall = metrics_dict['pr/recall'] = sum_a[METRICS_TP_NDX] \
            / ((sum_a[METRICS_TP_NDX] + sum_a[METRICS_FN_NDX]) or 1)

        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
            / ((precision + recall) or 1)

        log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
                  + "{pr/precision:.4f} precision, "
                  + "{pr/recall:.4f} recall, "
                  + "{pr/f1_score:.4f} f1 score"
                  ).format(
            epoch_ndx,
            mode_str,
            **metrics_dict,
        ))
        log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
                  + "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
                  ).format(
            epoch_ndx,
            mode_str + '_all',
            **metrics_dict,
        ))

        self.initTensorboardWriters()
        writer = getattr(self, mode_str + '_writer')

        prefix_str = 'seg_'

        for key, value in metrics_dict.items():
            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)

        writer.flush()

        score = metrics_dict['pr/recall']

        return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             writer.add_histogram(
    #                 name.rsplit('.', 1)[-1] + '/' + name,
    #                 param.data.cpu().numpy(),
    #                 # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                 self.totalTrainingSamples_count,
    #                 # bins=bins,
    #             )
    #
    #             # print name, param.data
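
    # Saves both the model and optimizer state so training can be resumed;
    # DataParallel wrappers are unwrapped first so the checkpoint loads cleanly
    # on a single device. The best-scoring checkpoint is additionally copied to
    # a .best.state file, and a SHA1 of the saved bytes is logged.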
    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )

        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.segmentation_model
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        state = {
            'sys_argv': sys.argv,
            'time': str(datetime.datetime.now()),
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)

        log.info("Saved model params to {}".format(file_path))

        if isBest:
            best_path = os.path.join(
                'data-unversioned',
                'part2',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(
                    type_str,
                    self.time_str,
                    self.cli_args.comment,
                    'best',
                )
            )
            shutil.copyfile(file_path, best_path)

            log.info("Saved model params to {}".format(best_path))

        with open(file_path, 'rb') as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())
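
# Example invocation (the module path is an assumption: the --tb-prefix
# default suggests this file lives in a package named p2ch13; adjust to match
# your layout):
#
#   python -m p2ch13.training --epochs 10 --augmented my-run-comment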

if __name__ == '__main__':
    SegmentationTrainingApp().main()