training.py

import argparse
import datetime
import os
import sys

import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)

# Indices into the first axis of the (3, num_samples) metrics arrays
LABEL = 0
PRED = 1
LOSS = 2
# ...
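
# Command-line application: parses arguments, builds the training and test
# data loaders, trains LunaModel, and logs per-epoch metrics to TensorBoard.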
class LunaTrainingApp(object):
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=256,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=10,
            type=int,
        )
        parser.add_argument('--layers',
            help='Number of layers in the model',
            default=3,
            type=int,
        )
        parser.add_argument('--channels',
            help="Number of channels for the first layer's convolutions (doubles with each layer)",
            default=8,
            type=int,
        )
        parser.add_argument('--balanced',
            help="Balance the training data to half benign, half malignant.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--scaled',
            help="Scale the CT chunks to square voxels.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augmented',
            help="Augment the training data (implies --scaled).",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            help="Data prefix to use for Tensorboard. Defaults to chapter.",
            default='p2ch4',
        )

        self.cli_args = parser.parse_args(sys_argv)
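
    # main() wires everything together: data loaders, model, optimizer,
    # TensorBoard writers, and the per-epoch train/test loops.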
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.train_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=False,
                balanced_bool=self.cli_args.balanced,
                scaled_bool=self.cli_args.scaled or self.cli_args.augmented,
                augmented_bool=self.cli_args.augmented,
            ),
            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
            num_workers=self.cli_args.num_workers,
            pin_memory=True,
        )
        self.test_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=True,
                scaled_bool=self.cli_args.scaled or self.cli_args.augmented,
                # augmented_bool=self.cli_args.augmented,
            ),
            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
            num_workers=self.cli_args.num_workers,
            pin_memory=True,
        )
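
        # The model is replicated across all visible GPUs via DataParallel,
        # so the effective batch size scales with torch.cuda.device_count().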
        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)

        time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        log_dir = os.path.join('runs', self.cli_args.tb_prefix, time_str)
        self.trn_writer = SummaryWriter(log_dir=log_dir + '_train')
        self.tst_writer = SummaryWriter(log_dir=log_dir + '_test')
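
        # Each epoch is one pass over the training set followed by one pass
        # over the test set; metrics for both are logged together afterwards.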
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(self.train_dl),
                len(self.test_dl),
                self.cli_args.batch_size,
                torch.cuda.device_count(),
            ))

            # Training loop, very similar to below
            self.model.train()
            self.train_dl.dataset.shuffleSamples()
            batch_iter = enumerateWithEstimate(
                self.train_dl,
                "E{} Training".format(epoch_ndx),
                start_ndx=self.train_dl.num_workers,
            )
            trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)), dtype=np.float32)
            for batch_ndx, batch_tup in batch_iter:
                self.optimizer.zero_grad()
                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, self.train_dl.batch_size, trainingMetrics_ary)
                loss_var.backward()
                self.optimizer.step()
                del loss_var

            # Testing loop, very similar to above, but simplified
            # ...
            self.model.eval()
            self.test_dl.dataset.shuffleSamples()
            batch_iter = enumerateWithEstimate(
                self.test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=self.test_dl.num_workers,
            )
            testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)), dtype=np.float32)
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup, self.test_dl.batch_size, testingMetrics_ary)

            self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)

        self.trn_writer.close()
        self.tst_writer.close()
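
    # Runs one batch through the model, records per-sample label, prediction,
    # and loss into metrics_ary, and returns the batch loss for backprop.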
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_ary):
        input_tensor, label_tensor, series_list, center_list = batch_tup
        input_var = Variable(input_tensor.cuda())
        label_var = Variable(label_tensor.cuda())
        prediction_var = self.model(input_var)
        # ...
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_tensor.size(0)
        metrics_ary[LABEL, start_ndx:end_ndx] = label_tensor.numpy()[:, 0, 0]
        metrics_ary[PRED, start_ndx:end_ndx] = prediction_var.data.cpu().numpy()[:, 0]
        for sample_ndx in range(label_tensor.size(0)):
            subloss_var = nn.MSELoss()(prediction_var[sample_ndx], label_var[sample_ndx])
            metrics_ary[LOSS, start_ndx + sample_ndx] = subloss_var.data[0]
            del subloss_var
        loss_var = nn.MSELoss()(prediction_var, label_var)
        return loss_var
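
    # Summarizes an epoch: computes precision, recall, F-scores, and per-class
    # loss/accuracy, logs them, and writes them to TensorBoard.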
    def logMetrics(self, epoch_ndx, trainingMetrics_ary, testingMetrics_ary):
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))
        for mode_str, metrics_ary in [('trn', trainingMetrics_ary), ('tst', testingMetrics_ary)]:
            pos_mask = metrics_ary[LABEL] > 0.5
            neg_mask = ~pos_mask

            truePos_count = (metrics_ary[PRED, pos_mask] > 0.5).sum()
            trueNeg_count = (metrics_ary[PRED, neg_mask] < 0.5).sum()
            falseNeg_count = pos_mask.sum() - truePos_count
            falsePos_count = neg_mask.sum() - trueNeg_count

            metrics_dict = {}
            metrics_dict['pr/precision'] = p = truePos_count / (truePos_count + falsePos_count)
            metrics_dict['pr/recall'] = r = truePos_count / (truePos_count + falseNeg_count)
            # https://en.wikipedia.org/wiki/F1_score
            for n in [0.5, 1, 2]:
                metrics_dict['pr/f{}_score'.format(n)] = \
                    (1 + n**2) * (p * r / (n**2 * p + r))

            metrics_dict['loss/all'] = metrics_ary[LOSS].mean()
            metrics_dict['loss/ben'] = metrics_ary[LOSS, neg_mask].mean()
            metrics_dict['loss/mal'] = metrics_ary[LOSS, pos_mask].mean()
            metrics_dict['correct/all'] = (truePos_count + trueNeg_count) / metrics_ary.shape[1] * 100
            metrics_dict['correct/ben'] = trueNeg_count / neg_mask.sum() * 100
            metrics_dict['correct/mal'] = truePos_count / pos_mask.sum() * 100

            log.info(("E{} {:8} "
                + "{loss/all:.4f} loss, "
                + "{correct/all:-5.1f}% correct, "
                + "{pr/precision:.4f} precision, "
                + "{pr/recall:.4f} recall").format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            ))
            log.info(("E{} {:8} "
                + "{loss/ben:.4f} loss, "
                + "{correct/ben:-5.1f}% correct").format(
                epoch_ndx,
                mode_str + '_ben',
                **metrics_dict,
            ))
            log.info(("E{} {:8} "
                + "{loss/mal:.4f} loss, "
                + "{correct/mal:-5.1f}% correct").format(
                epoch_ndx,
                mode_str + '_mal',
                **metrics_dict,
            ))

            writer = getattr(self, mode_str + '_writer')
            tb_ndx = epoch_ndx * trainingMetrics_ary.shape[1]
            for key, value in metrics_dict.items():
                writer.add_scalar(key, value, tb_ndx)
            writer.add_pr_curve('pr', metrics_ary[LABEL], metrics_ary[PRED], tb_ndx)
            writer.add_histogram('is_mal', metrics_ary[PRED, pos_mask], tb_ndx)
            writer.add_histogram('is_ben', metrics_ary[PRED, neg_mask], tb_ndx)
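
# Run as a module so the relative imports (.dsets, .model) resolve; the
# package name is an assumption based on the default --tb-prefix:
#   python -m p2ch4.training --epochs 10 --balanced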
if __name__ == '__main__':
    sys.exit(LunaTrainingApp().main() or 0)