# training.py
  1. import argparse
  2. import datetime
  3. import os
  4. import sys
  5. import numpy as np
  6. from tensorboardX import SummaryWriter
  7. import torch
  8. import torch.nn as nn
  9. from torch.optim import SGD
  10. from torch.utils.data import DataLoader
  11. from util.util import enumerateWithEstimate
  12. from .dsets import LunaDataset
  13. from util.logconf import logging
  14. from .model import LunaModel
  15. log = logging.getLogger(__name__)
  16. # log.setLevel(logging.WARN)
  17. log.setLevel(logging.INFO)
  18. # log.setLevel(logging.DEBUG)
  19. # Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
  20. METRICS_LABEL_NDX=0
  21. METRICS_PRED_NDX=1
  22. METRICS_LOSS_NDX=2
  23. class LunaTrainingApp(object):
  24. def __init__(self, sys_argv=None):
  25. if sys_argv is None:
  26. sys_argv = sys.argv[1:]
  27. parser = argparse.ArgumentParser()
  28. parser.add_argument('--batch-size',
  29. help='Batch size to use for training',
  30. default=32,
  31. type=int,
  32. )
  33. parser.add_argument('--num-workers',
  34. help='Number of worker processes for background data loading',
  35. default=8,
  36. type=int,
  37. )
  38. parser.add_argument('--epochs',
  39. help='Number of epochs to train for',
  40. default=1,
  41. type=int,
  42. )
  43. self.cli_args = parser.parse_args(sys_argv)
  44. self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
  45. def main(self):
  46. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  47. self.use_cuda = torch.cuda.is_available()
  48. self.device = torch.device("cuda" if self.use_cuda else "cpu")
  49. self.model = LunaModel()
  50. if self.use_cuda:
  51. if torch.cuda.device_count() > 1:
  52. self.model = nn.DataParallel(self.model)
  53. self.model = self.model.to(self.device)
  54. self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
  55. train_dl = DataLoader(
  56. LunaDataset(
  57. test_stride=10,
  58. isTestSet_bool=False,
  59. ),
  60. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  61. num_workers=self.cli_args.num_workers,
  62. pin_memory=self.use_cuda,
  63. )
  64. test_dl = DataLoader(
  65. LunaDataset(
  66. test_stride=10,
  67. isTestSet_bool=True,
  68. ),
  69. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  70. num_workers=self.cli_args.num_workers,
  71. pin_memory=self.use_cuda,
  72. )
  73. for epoch_ndx in range(1, self.cli_args.epochs + 1):
  74. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  75. epoch_ndx,
  76. self.cli_args.epochs,
  77. len(train_dl),
  78. len(test_dl),
  79. self.cli_args.batch_size,
  80. (torch.cuda.device_count() if self.use_cuda else 1),
  81. ))
  82. # Training loop, very similar to below
  83. self.model.train()
  84. trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
  85. batch_iter = enumerateWithEstimate(
  86. train_dl,
  87. "E{} Training".format(epoch_ndx),
  88. start_ndx=train_dl.num_workers,
  89. )
  90. for batch_ndx, batch_tup in batch_iter:
  91. self.optimizer.zero_grad()
  92. loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
  93. loss_var.backward()
  94. self.optimizer.step()
  95. del loss_var
  96. # Testing loop, very similar to above, but simplified
  97. with torch.no_grad():
  98. self.model.eval()
  99. testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
  100. batch_iter = enumerateWithEstimate(
  101. test_dl,
  102. "E{} Testing ".format(epoch_ndx),
  103. start_ndx=test_dl.num_workers,
  104. )
  105. for batch_ndx, batch_tup in batch_iter:
  106. self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
  107. self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
  108. def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
  109. input_tensor, label_tensor, _series_list, _center_list = batch_tup
  110. input_devtensor = input_tensor.to(self.device)
  111. label_devtensor = label_tensor.to(self.device)
  112. prediction_devtensor = self.model(input_devtensor)
  113. loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
  114. start_ndx = batch_ndx * batch_size
  115. end_ndx = start_ndx + label_tensor.size(0)
  116. metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
  117. metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
  118. metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
  119. # TODO: replace with torch.autograd.detect_anomaly
  120. # assert np.isfinite(metrics_tensor).all()
  121. return loss_devtensor.mean()
  122. def logMetrics(self,
  123. epoch_ndx,
  124. trainingMetrics_tensor,
  125. testingMetrics_tensor,
  126. classificationThreshold_float=0.5,
  127. ):
  128. log.info("E{} {}".format(
  129. epoch_ndx,
  130. type(self).__name__,
  131. ))
  132. for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
  133. metrics_ary = metrics_tensor.detach().numpy()[:,:,0]
  134. assert np.isfinite(metrics_ary).all()
  135. benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
  136. benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
  137. malLabel_mask = ~benLabel_mask
  138. malPred_mask = ~benPred_mask
  139. benLabel_count = benLabel_mask.sum()
  140. malLabel_count = malLabel_mask.sum()
  141. benCorrect_count = (benLabel_mask & benPred_mask).sum()
  142. malCorrect_count = (malLabel_mask & malPred_mask).sum()
  143. metrics_dict = {}
  144. metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
  145. metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
  146. metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
  147. metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
  148. metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
  149. metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
  150. log.info(("E{} {:8} "
  151. + "{loss/all:.4f} loss, "
  152. + "{correct/all:-5.1f}% correct"
  153. ).format(
  154. epoch_ndx,
  155. mode_str,
  156. **metrics_dict,
  157. ))
  158. log.info(("E{} {:8} "
  159. + "{loss/ben:.4f} loss, "
  160. + "{correct/ben:-5.1f}% correct").format(
  161. epoch_ndx,
  162. mode_str + '_ben',
  163. **metrics_dict,
  164. ))
  165. log.info(("E{} {:8} "
  166. + "{loss/mal:.4f} loss, "
  167. + "{correct/mal:-5.1f}% correct").format(
  168. epoch_ndx,
  169. mode_str + '_mal',
  170. **metrics_dict,
  171. ))
  172. if __name__ == '__main__':
  173. sys.exit(LunaTrainingApp().main() or 0)