# training.py — LUNA nodule-classifier training application.
  1. import argparse
  2. import sys
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from torch.autograd import Variable
  7. from torch.optim import SGD
  8. from torch.utils.data import DataLoader
  9. from util.util import enumerateWithEstimate
  10. from .dsets import LunaDataset
  11. from util.logconf import logging
  12. from .model import LunaModel
  13. log = logging.getLogger(__name__)
  14. # log.setLevel(logging.WARN)
  15. log.setLevel(logging.INFO)
  16. # log.setLevel(logging.DEBUG)
  17. # Used for metrics_ary index 0
  18. LABEL=0
  19. PRED=1
  20. LOSS=2
  21. # ...
  22. class LunaTrainingApp(object):
  23. @classmethod
  24. def __init__(self, sys_argv=None):
  25. if sys_argv is None:
  26. sys_argv = sys.argv[1:]
  27. parser = argparse.ArgumentParser()
  28. parser.add_argument('--batch-size',
  29. help='Batch size to use for training',
  30. default=256,
  31. type=int,
  32. )
  33. parser.add_argument('--num-workers',
  34. help='Number of worker processes for background data loading',
  35. default=8,
  36. type=int,
  37. )
  38. parser.add_argument('--epochs',
  39. help='Number of epochs to train for',
  40. default=10,
  41. type=int,
  42. )
  43. parser.add_argument('--layers',
  44. help='Number of layers to the model',
  45. default=3,
  46. type=int,
  47. )
  48. parser.add_argument('--channels',
  49. help="Number of channels for the first layer's convolutions to the model (doubles each layer)",
  50. default=8,
  51. type=int,
  52. )
  53. self.cli_args = parser.parse_args(sys_argv)
  54. def main(self):
  55. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  56. self.train_dl = DataLoader(
  57. LunaDataset(
  58. test_stride=10,
  59. isTestSet_bool=False,
  60. ),
  61. batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
  62. num_workers=self.cli_args.num_workers,
  63. pin_memory=True,
  64. )
  65. self.test_dl = DataLoader(
  66. LunaDataset(
  67. test_stride=10,
  68. isTestSet_bool=True,
  69. ),
  70. batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
  71. num_workers=self.cli_args.num_workers,
  72. pin_memory=True,
  73. )
  74. self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
  75. self.model = nn.DataParallel(self.model)
  76. self.model = self.model.cuda()
  77. self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
  78. for epoch_ndx in range(1, self.cli_args.epochs + 1):
  79. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  80. epoch_ndx,
  81. self.cli_args.epochs,
  82. len(self.train_dl),
  83. len(self.test_dl),
  84. self.cli_args.batch_size,
  85. torch.cuda.device_count(),
  86. ))
  87. # Training loop, very similar to below
  88. self.model.train()
  89. batch_iter = enumerateWithEstimate(
  90. self.train_dl,
  91. "E{} Training".format(epoch_ndx),
  92. start_ndx=self.train_dl.num_workers,
  93. )
  94. trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)), dtype=np.float32)
  95. for batch_ndx, batch_tup in batch_iter:
  96. self.optimizer.zero_grad()
  97. loss_var = self.computeBatchLoss(batch_ndx, batch_tup, self.train_dl.batch_size, trainingMetrics_ary)
  98. loss_var.backward()
  99. self.optimizer.step()
  100. del loss_var
  101. # Testing loop, very similar to above, but simplified
  102. # ...
  103. self.model.eval()
  104. batch_iter = enumerateWithEstimate(
  105. self.test_dl,
  106. "E{} Testing ".format(epoch_ndx),
  107. start_ndx=self.test_dl.num_workers,
  108. )
  109. testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)), dtype=np.float32)
  110. for batch_ndx, batch_tup in batch_iter:
  111. self.computeBatchLoss(batch_ndx, batch_tup, self.test_dl.batch_size, testingMetrics_ary)
  112. self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)
  113. def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_ary):
  114. input_tensor, label_tensor, series_list, center_list = batch_tup
  115. input_var = Variable(input_tensor.cuda())
  116. label_var = Variable(label_tensor.cuda())
  117. prediction_var = self.model(input_var)
  118. # ...
  119. start_ndx = batch_ndx * batch_size
  120. end_ndx = start_ndx + label_tensor.size(0)
  121. metrics_ary[LABEL, start_ndx:end_ndx] = label_tensor.numpy()[:,0,0]
  122. metrics_ary[PRED, start_ndx:end_ndx] = prediction_var.data.cpu().numpy()[:,0]
  123. for sample_ndx in range(label_tensor.size(0)):
  124. subloss_var = nn.MSELoss()(prediction_var[sample_ndx], label_var[sample_ndx])
  125. metrics_ary[LOSS, start_ndx+sample_ndx] = subloss_var.data[0]
  126. del subloss_var
  127. loss_var = nn.MSELoss()(prediction_var, label_var)
  128. return loss_var
  129. def logMetrics(self, epoch_ndx, trainingMetrics_ary, testingMetrics_ary):
  130. log.info("E{} {}".format(
  131. epoch_ndx,
  132. type(self).__name__,
  133. ))
  134. for mode_str, metrics_ary in [('trn', trainingMetrics_ary), ('tst', testingMetrics_ary)]:
  135. pos_mask = metrics_ary[LABEL] > 0.5
  136. neg_mask = ~pos_mask
  137. truePos_count = (metrics_ary[PRED, pos_mask] > 0.5).sum()
  138. trueNeg_count = (metrics_ary[PRED, neg_mask] < 0.5).sum()
  139. metrics_dict = {}
  140. metrics_dict['loss/all'] = metrics_ary[LOSS].mean()
  141. metrics_dict['loss/ben'] = metrics_ary[LOSS, neg_mask].mean()
  142. metrics_dict['loss/mal'] = metrics_ary[LOSS, pos_mask].mean()
  143. metrics_dict['correct/all'] = (truePos_count + trueNeg_count) / metrics_ary.shape[1] * 100
  144. metrics_dict['correct/ben'] = (trueNeg_count) / neg_mask.sum() * 100
  145. metrics_dict['correct/mal'] = (truePos_count) / pos_mask.sum() * 100
  146. log.info("E{} {:8} {loss/all:.4f} loss, {correct/all:-5.1f}% correct".format(
  147. epoch_ndx,
  148. mode_str,
  149. **metrics_dict,
  150. ))
  151. log.info("E{} {:8} {loss/ben:.4f} loss, {correct/ben:-5.1f}% correct".format(
  152. epoch_ndx,
  153. mode_str + '_ben',
  154. **metrics_dict,
  155. ))
  156. log.info("E{} {:8} {loss/mal:.4f} loss, {correct/mal:-5.1f}% correct".format(
  157. epoch_ndx,
  158. mode_str + '_mal',
  159. **metrics_dict,
  160. ))
  161. if __name__ == '__main__':
  162. sys.exit(LunaTrainingApp().main() or 0)