# training.py

import argparse
import datetime
import os
import sys

import numpy as np

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
METRICS_LABEL_NDX = 0
METRICS_PRED_NDX = 1
METRICS_LOSS_NDX = 2
METRICS_SIZE = 3
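
# The metrics tensors allocated in doTraining and doValidation have shape
# (METRICS_SIZE, len(dataset)): one row per metric above and one column per
# sample, so whole-epoch statistics can be computed after the loop finishes.
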
class LunaTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=32,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch11',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='dwlpt',
        )

        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()
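
    # The model is created and moved to the GPU (when available) before the
    # optimizer is constructed, so the optimizer is handed the on-device
    # parameters.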
    def initModel(self):
        model = LunaModel()
        if self.use_cuda:
            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model
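
    # SGD with a fixed learning rate and high momentum; Adam is left in as a
    # commented-out alternative that needs no hand-tuned learning rate.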
    def initOptimizer(self):
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
        # return Adam(self.model.parameters())
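
    # The per-process batch size is multiplied by the GPU count because
    # nn.DataParallel splits each batch across the available devices;
    # pin_memory=True allocates page-locked host memory, which speeds up
    # host-to-GPU transfers.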
    def initTrainDl(self):
        train_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=False,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl

    def initValDl(self):
        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl
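
    # Writers are created lazily, so runs/<tb-prefix>/<timestamp>-... log
    # directories only appear once training actually starts; training and
    # validation metrics go to separate writers so TensorBoard can overlay
    # them as separate runs.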
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
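
    # Top-level entry point: one full training pass and one full validation
    # pass per epoch, with the metrics from each logged via logMetrics.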
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        self.initTensorboardWriters()

        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            self.logMetrics(epoch_ndx, 'val', valMetrics_t)

        # hasattr(self, 'trn_writer') was always True, since __init__ sets the
        # attribute to None; check for None instead.
        if self.trn_writer is not None:
            self.trn_writer.close()
            self.val_writer.close()
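
    # Standard training loop: zero the gradients, compute the batch loss,
    # backpropagate, and step the optimizer. Per-sample metrics accumulate on
    # the GPU and move to the CPU in a single transfer at the end of the epoch.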
    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g,
            )

            loss_var.backward()
            self.optimizer.step()

            # # This is for adding the model graph to TensorBoard.
            # if epoch_ndx == 1 and batch_ndx == 0:
            #     with torch.no_grad():
            #         model = LunaModel()
            #         self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
            #         self.trn_writer.close()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to('cpu')
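
    # torch.no_grad() turns off autograd bookkeeping, and model.eval()
    # disables training-only behavior such as dropout and batch-norm
    # statistics updates; no weights change during validation.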
    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')
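
    # reduction='none' yields one loss value per sample, which is recorded
    # (along with the label and predicted probability) into the epoch-wide
    # metrics tensor; the mean is returned so the caller can backpropagate a
    # scalar.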
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
        input_t, label_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        logits_g, probability_g = self.model(input_g)

        loss_func = nn.CrossEntropyLoss(reduction='none')
        loss_g = loss_func(
            logits_g,
            label_g[:,1],
        )

        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
            label_g[:,1].detach()
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
            probability_g[:,1].detach()
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
            loss_g.detach()

        return loss_g.mean()
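
    # Samples are split into negative/positive by label and by prediction
    # thresholded at classificationThreshold; per-class loss and accuracy are
    # logged, plus a PR curve and prediction histograms, to the writer
    # matching mode_str ('trn' or 'val').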
    def logMetrics(
            self,
            epoch_ndx,
            mode_str,
            metrics_t,
            classificationThreshold=0.5,
    ):
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
        negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())

        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())

        metrics_dict = {}
        metrics_dict['loss/all'] = \
            metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = \
            metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = \
            metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()

        metrics_dict['correct/all'] = \
            (pos_correct + neg_correct) / np.float32(metrics_t.shape[1]) * 100
        metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
        metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100

        log.info(
            ("E{} {:8} {loss/all:.4f} loss, "
             + "{correct/all:-5.1f}% correct, "
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/neg:.4f} loss, "
             + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_neg',
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/pos:.4f} loss, "
             + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_pos',
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )

        writer = getattr(self, mode_str + '_writer')

        for key, value in metrics_dict.items():
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        writer.add_pr_curve(
            'pr',
            metrics_t[METRICS_LABEL_NDX],
            metrics_t[METRICS_PRED_NDX],
            self.totalTrainingSamples_count,
        )

        bins = [x/50.0 for x in range(51)]

        negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
        posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)

        if negHist_mask.any():
            writer.add_histogram(
                'is_neg',
                metrics_t[METRICS_PRED_NDX, negHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )
        if posHist_mask.any():
            writer.add_histogram(
                'is_pos',
                metrics_t[METRICS_PRED_NDX, posHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

        # score = 1 \
        #     + metrics_dict['pr/f1_score'] \
        #     - metrics_dict['loss/mal'] * 0.01 \
        #     - metrics_dict['loss/all'] * 0.0001
        #
        # return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             try:
    #                 writer.add_histogram(
    #                     name.rsplit('.', 1)[-1] + '/' + name,
    #                     param.data.cpu().numpy(),
    #                     # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                     self.totalTrainingSamples_count,
    #                     # bins=bins,
    #                 )
    #             except Exception as e:
    #                 log.error([min_data, max_data])
    #                 raise


if __name__ == '__main__':
    LunaTrainingApp().main()
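
# Example invocation (a hypothetical sketch, assuming this file lives in a
# p2ch11 package as the --tb-prefix default suggests; the positional
# 'comment' argument tags the TensorBoard run directory):
#
#     python -m p2ch11.training --epochs 1 --batch-size 32 my-first-run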