# training.py
  1. import argparse
  2. import datetime
  3. import os
  4. import sys
  5. import numpy as np
  6. from tensorboardX import SummaryWriter
  7. import torch
  8. import torch.nn as nn
  9. from torch.optim import SGD
  10. from torch.utils.data import DataLoader
  11. from util.util import enumerateWithEstimate
  12. from .dsets import LunaDataset
  13. from util.logconf import logging
  14. from .model import LunaModel
log = logging.getLogger(__name__)
# Module-level log verbosity; swap the commented lines to change it.
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)

# Row indices used by computeBatchLoss and logMetrics to index the first
# dimension of the shared metrics_tensor/metrics_ary: one row per statistic
# (label, prediction, per-sample loss), one column per sample.
METRICS_LABEL_NDX=0
METRICS_PRED_NDX=1
METRICS_LOSS_NDX=2
  23. class LunaTrainingApp(object):
  24. def __init__(self, sys_argv=None):
  25. if sys_argv is None:
  26. sys_argv = sys.argv[1:]
  27. parser = argparse.ArgumentParser()
  28. parser.add_argument('--batch-size',
  29. help='Batch size to use for training',
  30. default=32,
  31. type=int,
  32. )
  33. parser.add_argument('--num-workers',
  34. help='Number of worker processes for background data loading',
  35. default=8,
  36. type=int,
  37. )
  38. parser.add_argument('--epochs',
  39. help='Number of epochs to train for',
  40. default=1,
  41. type=int,
  42. )
  43. parser.add_argument('--balanced',
  44. help="Balance the training data to half benign, half malignant.",
  45. action='store_true',
  46. default=False,
  47. )
  48. parser.add_argument('--tb-prefix',
  49. default='p2ch10',
  50. help="Data prefix to use for Tensorboard run. Defaults to chapter.",
  51. )
  52. parser.add_argument('comment',
  53. help="Comment suffix for Tensorboard run.",
  54. nargs='?',
  55. default='none',
  56. )
  57. self.cli_args = parser.parse_args(sys_argv)
  58. self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
  59. def main(self):
  60. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  61. self.use_cuda = torch.cuda.is_available()
  62. self.device = torch.device("cuda" if self.use_cuda else "cpu")
  63. self.totalTrainingSamples_count = 0
  64. self.model = LunaModel()
  65. if self.use_cuda:
  66. if torch.cuda.device_count() > 1:
  67. self.model = nn.DataParallel(self.model)
  68. self.model = self.model.to(self.device)
  69. self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
  70. train_dl = DataLoader(
  71. LunaDataset(
  72. test_stride=10,
  73. isTestSet_bool=False,
  74. ratio_int=int(self.cli_args.balanced),
  75. ),
  76. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  77. num_workers=self.cli_args.num_workers,
  78. pin_memory=self.use_cuda,
  79. )
  80. test_dl = DataLoader(
  81. LunaDataset(
  82. test_stride=10,
  83. isTestSet_bool=True,
  84. ),
  85. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  86. num_workers=self.cli_args.num_workers,
  87. pin_memory=self.use_cuda,
  88. )
  89. for epoch_ndx in range(1, self.cli_args.epochs + 1):
  90. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  91. epoch_ndx,
  92. self.cli_args.epochs,
  93. len(train_dl),
  94. len(test_dl),
  95. self.cli_args.batch_size,
  96. (torch.cuda.device_count() if self.use_cuda else 1),
  97. ))
  98. # Training loop, very similar to below
  99. self.model.train()
  100. trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
  101. train_dl.dataset.shuffleSamples()
  102. batch_iter = enumerateWithEstimate(
  103. train_dl,
  104. "E{} Training".format(epoch_ndx),
  105. start_ndx=train_dl.num_workers,
  106. )
  107. for batch_ndx, batch_tup in batch_iter:
  108. self.optimizer.zero_grad()
  109. loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
  110. loss_var.backward()
  111. self.optimizer.step()
  112. del loss_var
  113. # Testing loop, very similar to above, but simplified
  114. with torch.no_grad():
  115. self.model.eval()
  116. testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
  117. batch_iter = enumerateWithEstimate(
  118. test_dl,
  119. "E{} Testing ".format(epoch_ndx),
  120. start_ndx=test_dl.num_workers,
  121. )
  122. for batch_ndx, batch_tup in batch_iter:
  123. self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
  124. self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
  125. if hasattr(self, 'trn_writer'):
  126. self.trn_writer.close()
  127. self.tst_writer.close()
  128. def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
  129. input_tensor, label_tensor, _series_list, _center_list = batch_tup
  130. input_devtensor = input_tensor.to(self.device)
  131. label_devtensor = label_tensor.to(self.device)
  132. prediction_devtensor = self.model(input_devtensor)
  133. loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
  134. start_ndx = batch_ndx * batch_size
  135. end_ndx = start_ndx + label_tensor.size(0)
  136. metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
  137. metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
  138. metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
  139. # TODO: replace with torch.autograd.detect_anomaly
  140. # assert np.isfinite(metrics_tensor).all()
  141. return loss_devtensor.mean()
  142. def logMetrics(self,
  143. epoch_ndx,
  144. trainingMetrics_tensor,
  145. testingMetrics_tensor,
  146. classificationThreshold_float=0.5,
  147. ):
  148. log.info("E{} {}".format(
  149. epoch_ndx,
  150. type(self).__name__,
  151. ))
  152. if epoch_ndx == 2:
  153. log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
  154. self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_' + self.cli_args.comment)
  155. self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_' + self.cli_args.comment)
  156. self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
  157. for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
  158. metrics_ary = metrics_tensor.cpu().detach().numpy()[:,:,0]
  159. assert np.isfinite(metrics_ary).all()
  160. benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
  161. benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
  162. malLabel_mask = ~benLabel_mask
  163. malPred_mask = ~benPred_mask
  164. benLabel_count = benLabel_mask.sum()
  165. malLabel_count = malLabel_mask.sum()
  166. trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
  167. truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
  168. falsePos_count = benLabel_count - benCorrect_count
  169. falseNeg_count = malLabel_count - malCorrect_count
  170. metrics_dict = {}
  171. metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
  172. metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
  173. metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
  174. metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
  175. metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
  176. metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
  177. precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
  178. recall = metrics_dict['pr/recall'] = truePos_count / (truePos_count + falseNeg_count)
  179. metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
  180. log.info(("E{} {:8} "
  181. + "{loss/all:.4f} loss, "
  182. + "{correct/all:-5.1f}% correct, "
  183. + "{pr/precision:.4f} precision, "
  184. + "{pr/recall:.4f} recall, "
  185. + "{pr/f1_score:.4f} f1 score"
  186. ).format(
  187. epoch_ndx,
  188. mode_str,
  189. **metrics_dict,
  190. ))
  191. log.info(("E{} {:8} "
  192. + "{loss/ben:.4f} loss, "
  193. + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})").format(
  194. epoch_ndx,
  195. mode_str + '_ben',
  196. benCorrect_count=benCorrect_count,
  197. benLabel_count=benLabel_count,
  198. **metrics_dict,
  199. ))
  200. log.info(("E{} {:8} "
  201. + "{loss/mal:.4f} loss, "
  202. + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})").format(
  203. epoch_ndx,
  204. mode_str + '_mal',
  205. malCorrect_count=malCorrect_count,
  206. malLabel_count=malLabel_count,
  207. **metrics_dict,
  208. ))
  209. if epoch_ndx > 1:
  210. writer = getattr(self, mode_str + '_writer')
  211. for key, value in metrics_dict.items():
  212. writer.add_scalar(key, value, self.totalTrainingSamples_count)
  213. writer.add_pr_curve(
  214. 'pr',
  215. metrics_ary[METRICS_LABEL_NDX],
  216. metrics_ary[METRICS_PRED_NDX],
  217. self.totalTrainingSamples_count,
  218. )
  219. benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
  220. malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
  221. bins = [x/50.0 for x in range(51)]
  222. writer.add_histogram(
  223. 'is_ben',
  224. metrics_ary[METRICS_PRED_NDX, benHist_mask],
  225. self.totalTrainingSamples_count,
  226. bins=bins,
  227. )
  228. writer.add_histogram(
  229. 'is_mal',
  230. metrics_ary[METRICS_PRED_NDX, malHist_mask],
  231. self.totalTrainingSamples_count,
  232. bins=bins,
  233. )
  234. if __name__ == '__main__':
  235. sys.exit(LunaTrainingApp().main() or 0)