# training.py — LUNA nodule-classifier training application (p2ch10).
import argparse
import datetime
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.optim import SGD
from torch.utils.data import DataLoader

from util.logconf import logging
from util.util import enumerateWithEstimate

from .dsets import LunaDataset
from .model import LunaModel
  15. log = logging.getLogger(__name__)
  16. # log.setLevel(logging.WARN)
  17. log.setLevel(logging.INFO)
  18. # log.setLevel(logging.DEBUG)
  19. # Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
  20. METRICS_LABEL_NDX=0
  21. METRICS_PRED_NDX=1
  22. METRICS_LOSS_NDX=2
  23. METRICS_SIZE = 3
  24. class LunaTrainingApp(object):
  25. def __init__(self, sys_argv=None):
  26. if sys_argv is None:
  27. sys_argv = sys.argv[1:]
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument('--batch-size',
  30. help='Batch size to use for training',
  31. default=32,
  32. type=int,
  33. )
  34. parser.add_argument('--num-workers',
  35. help='Number of worker processes for background data loading',
  36. default=8,
  37. type=int,
  38. )
  39. parser.add_argument('--epochs',
  40. help='Number of epochs to train for',
  41. default=1,
  42. type=int,
  43. )
  44. parser.add_argument('--tb-prefix',
  45. default='p2ch10',
  46. help="Data prefix to use for Tensorboard run. Defaults to chapter.",
  47. )
  48. parser.add_argument('comment',
  49. help="Comment suffix for Tensorboard run.",
  50. nargs='?',
  51. default='none',
  52. )
  53. self.cli_args = parser.parse_args(sys_argv)
  54. self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
  55. self.trn_writer = None
  56. self.tst_writer = None
  57. self.totalTrainingSamples_count = 0
  58. self.use_cuda = torch.cuda.is_available()
  59. self.device = torch.device("cuda" if self.use_cuda else "cpu")
  60. self.model = self.initModel()
  61. self.optimizer = self.initOptimizer()
  62. def initModel(self):
  63. model = LunaModel()
  64. if self.use_cuda:
  65. if torch.cuda.device_count() > 1:
  66. model = nn.DataParallel(model)
  67. model = model.to(self.device)
  68. return model
  69. def initOptimizer(self):
  70. return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
  71. # return Adam(self.model.parameters())
  72. def initTrainDl(self):
  73. train_ds = LunaDataset(
  74. test_stride=10,
  75. isTestSet_bool=False,
  76. )
  77. train_dl = DataLoader(
  78. train_ds,
  79. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  80. num_workers=self.cli_args.num_workers,
  81. pin_memory=self.use_cuda,
  82. )
  83. return train_dl
  84. def initTestDl(self):
  85. test_ds = LunaDataset(
  86. test_stride=10,
  87. isTestSet_bool=True,
  88. )
  89. test_dl = DataLoader(
  90. test_ds,
  91. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  92. num_workers=self.cli_args.num_workers,
  93. pin_memory=self.use_cuda,
  94. )
  95. return test_dl
  96. def initTensorboardWriters(self):
  97. if self.trn_writer is None:
  98. log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
  99. self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
  100. self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
  101. # eng::tb_writer[]
  102. def main(self):
  103. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  104. train_dl = self.initTrainDl()
  105. test_dl = self.initTestDl()
  106. self.initTensorboardWriters()
  107. # self.logModelMetrics(self.model)
  108. # best_score = 0.0
  109. for epoch_ndx in range(1, self.cli_args.epochs + 1):
  110. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  111. epoch_ndx,
  112. self.cli_args.epochs,
  113. len(train_dl),
  114. len(test_dl),
  115. self.cli_args.batch_size,
  116. (torch.cuda.device_count() if self.use_cuda else 1),
  117. ))
  118. trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
  119. self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
  120. tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
  121. self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
  122. if hasattr(self, 'trn_writer'):
  123. self.trn_writer.close()
  124. self.tst_writer.close()
  125. def doTraining(self, epoch_ndx, train_dl):
  126. self.model.train()
  127. trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
  128. batch_iter = enumerateWithEstimate(
  129. train_dl,
  130. "E{} Training".format(epoch_ndx),
  131. start_ndx=train_dl.num_workers,
  132. )
  133. for batch_ndx, batch_tup in batch_iter:
  134. self.optimizer.zero_grad()
  135. loss_var = self.computeBatchLoss(
  136. batch_ndx,
  137. batch_tup,
  138. train_dl.batch_size,
  139. trainingMetrics_devtensor
  140. )
  141. loss_var.backward()
  142. self.optimizer.step()
  143. del loss_var
  144. self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
  145. return trainingMetrics_devtensor.to('cpu')
  146. def doTesting(self, epoch_ndx, test_dl):
  147. with torch.no_grad():
  148. self.model.eval()
  149. testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
  150. batch_iter = enumerateWithEstimate(
  151. test_dl,
  152. "E{} Testing ".format(epoch_ndx),
  153. start_ndx=test_dl.num_workers,
  154. )
  155. for batch_ndx, batch_tup in batch_iter:
  156. self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
  157. return testingMetrics_devtensor.to('cpu')
  158. def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
  159. input_tensor, label_tensor, _series_list, _center_list = batch_tup
  160. input_devtensor = input_tensor.to(self.device, non_blocking=True)
  161. label_devtensor = label_tensor.to(self.device, non_blocking=True)
  162. logits_devtensor, probability_devtensor = self.model(input_devtensor)
  163. loss_func = nn.CrossEntropyLoss(reduction='none')
  164. loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
  165. start_ndx = batch_ndx * batch_size
  166. end_ndx = start_ndx + label_tensor.size(0)
  167. metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
  168. metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
  169. metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
  170. return loss_devtensor.mean()
  171. def logMetrics(
  172. self,
  173. epoch_ndx,
  174. mode_str,
  175. metrics_tensor,
  176. ):
  177. log.info("E{} {}".format(
  178. epoch_ndx,
  179. type(self).__name__,
  180. ))
  181. metrics_ary = metrics_tensor.cpu().detach().numpy()
  182. # assert np.isfinite(metrics_ary).all()
  183. benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
  184. benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
  185. malLabel_mask = ~benLabel_mask
  186. malPred_mask = ~benPred_mask
  187. benLabel_count = benLabel_mask.sum()
  188. malLabel_count = malLabel_mask.sum()
  189. benCorrect_count = (benLabel_mask & benPred_mask).sum()
  190. malCorrect_count = (malLabel_mask & malPred_mask).sum()
  191. # trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
  192. # truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
  193. #
  194. # falsePos_count = benLabel_count - benCorrect_count
  195. # falseNeg_count = malLabel_count - malCorrect_count
  196. # log.info(['min loss', metrics_ary[METRICS_LOSS_NDX, benLabel_mask].min(), metrics_ary[METRICS_LOSS_NDX, malLabel_mask].min()])
  197. # log.info(['max loss', metrics_ary[METRICS_LOSS_NDX, benLabel_mask].max(), metrics_ary[METRICS_LOSS_NDX, malLabel_mask].max()])
  198. metrics_dict = {}
  199. metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
  200. metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
  201. metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
  202. metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
  203. metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
  204. metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
  205. log.info(
  206. ("E{} {:8} "
  207. + "{loss/all:.4f} loss, "
  208. + "{correct/all:-5.1f}% correct, "
  209. ).format(
  210. epoch_ndx,
  211. mode_str,
  212. **metrics_dict,
  213. )
  214. )
  215. log.info(
  216. ("E{} {:8} "
  217. + "{loss/ben:.4f} loss, "
  218. + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
  219. ).format(
  220. epoch_ndx,
  221. mode_str + '_ben',
  222. benCorrect_count=benCorrect_count,
  223. benLabel_count=benLabel_count,
  224. **metrics_dict,
  225. )
  226. )
  227. log.info(
  228. ("E{} {:8} "
  229. + "{loss/mal:.4f} loss, "
  230. + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
  231. ).format(
  232. epoch_ndx,
  233. mode_str + '_mal',
  234. malCorrect_count=malCorrect_count,
  235. malLabel_count=malLabel_count,
  236. **metrics_dict,
  237. )
  238. )
  239. writer = getattr(self, mode_str + '_writer')
  240. for key, value in metrics_dict.items():
  241. writer.add_scalar(key, value, self.totalTrainingSamples_count)
  242. writer.add_pr_curve(
  243. 'pr',
  244. metrics_ary[METRICS_LABEL_NDX],
  245. metrics_ary[METRICS_PRED_NDX],
  246. self.totalTrainingSamples_count,
  247. )
  248. bins = [x/50.0 for x in range(51)]
  249. benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
  250. malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
  251. if benHist_mask.any():
  252. writer.add_histogram(
  253. 'is_ben',
  254. metrics_ary[METRICS_PRED_NDX, benHist_mask],
  255. self.totalTrainingSamples_count,
  256. bins=bins,
  257. )
  258. if malHist_mask.any():
  259. writer.add_histogram(
  260. 'is_mal',
  261. metrics_ary[METRICS_PRED_NDX, malHist_mask],
  262. self.totalTrainingSamples_count,
  263. bins=bins,
  264. )
  265. # score = 1 \
  266. # + metrics_dict['pr/f1_score'] \
  267. # - metrics_dict['loss/mal'] * 0.01 \
  268. # - metrics_dict['loss/all'] * 0.0001
  269. #
  270. # return score
  271. # def logModelMetrics(self, model):
  272. # writer = getattr(self, 'trn_writer')
  273. #
  274. # model = getattr(model, 'module', model)
  275. #
  276. # for name, param in model.named_parameters():
  277. # if param.requires_grad:
  278. # min_data = float(param.data.min())
  279. # max_data = float(param.data.max())
  280. # max_extent = max(abs(min_data), abs(max_data))
  281. #
  282. # # bins = [x/50*max_extent for x in range(-50, 51)]
  283. #
  284. # try:
  285. # writer.add_histogram(
  286. # name.rsplit('.', 1)[-1] + '/' + name,
  287. # param.data.cpu().numpy(),
  288. # # metrics_ary[METRICS_PRED_NDX, benHist_mask],
  289. # self.totalTrainingSamples_count,
  290. # # bins=bins,
  291. # )
  292. # except Exception as e:
  293. # log.error([min_data, max_data])
  294. # raise
  295. if __name__ == '__main__':
  296. sys.exit(LunaTrainingApp().main() or 0)