training.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. import argparse
  2. import datetime
  3. import os
  4. import sys
  5. import numpy as np
  6. from torch.utils.tensorboard import SummaryWriter
  7. import torch
  8. import torch.nn as nn
  9. from torch.optim import SGD
  10. from torch.utils.data import DataLoader
  11. from util.util import enumerateWithEstimate
  12. from .dsets import LunaDataset
  13. from util.logconf import logging
  14. from .model import LunaModel
  15. log = logging.getLogger(__name__)
  16. # log.setLevel(logging.WARN)
  17. log.setLevel(logging.INFO)
  18. # log.setLevel(logging.DEBUG)
  19. # Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
  20. METRICS_LABEL_NDX=0
  21. METRICS_PRED_NDX=1
  22. METRICS_LOSS_NDX=2
  23. METRICS_SIZE = 3
  24. class LunaTrainingApp:
  25. def __init__(self, sys_argv=None):
  26. if sys_argv is None:
  27. sys_argv = sys.argv[1:]
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument('--batch-size',
  30. help='Batch size to use for training',
  31. default=32,
  32. type=int,
  33. )
  34. parser.add_argument('--num-workers',
  35. help='Number of worker processes for background data loading',
  36. default=8,
  37. type=int,
  38. )
  39. parser.add_argument('--epochs',
  40. help='Number of epochs to train for',
  41. default=1,
  42. type=int,
  43. )
  44. parser.add_argument('--balanced',
  45. help="Balance the training data to half positive, half negative.",
  46. action='store_true',
  47. default=False,
  48. )
  49. parser.add_argument('--augmented',
  50. help="Augment the training data.",
  51. action='store_true',
  52. default=False,
  53. )
  54. parser.add_argument('--augment-flip',
  55. help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
  56. action='store_true',
  57. default=False,
  58. )
  59. parser.add_argument('--augment-offset',
  60. help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
  61. action='store_true',
  62. default=False,
  63. )
  64. parser.add_argument('--augment-scale',
  65. help="Augment the training data by randomly increasing or decreasing the size of the candidate.",
  66. action='store_true',
  67. default=False,
  68. )
  69. parser.add_argument('--augment-rotate',
  70. help="Augment the training data by randomly rotating the data around the head-foot axis.",
  71. action='store_true',
  72. default=False,
  73. )
  74. parser.add_argument('--augment-noise',
  75. help="Augment the training data by randomly adding noise to the data.",
  76. action='store_true',
  77. default=False,
  78. )
  79. parser.add_argument('--tb-prefix',
  80. default='p2ch12',
  81. help="Data prefix to use for Tensorboard run. Defaults to chapter.",
  82. )
  83. parser.add_argument('comment',
  84. help="Comment suffix for Tensorboard run.",
  85. nargs='?',
  86. default='dlwpt',
  87. )
  88. self.cli_args = parser.parse_args(sys_argv)
  89. self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
  90. self.trn_writer = None
  91. self.val_writer = None
  92. self.totalTrainingSamples_count = 0
  93. self.augmentation_dict = {}
  94. if self.cli_args.augmented or self.cli_args.augment_flip:
  95. self.augmentation_dict['flip'] = True
  96. if self.cli_args.augmented or self.cli_args.augment_offset:
  97. self.augmentation_dict['offset'] = 0.1
  98. if self.cli_args.augmented or self.cli_args.augment_scale:
  99. self.augmentation_dict['scale'] = 0.2
  100. if self.cli_args.augmented or self.cli_args.augment_rotate:
  101. self.augmentation_dict['rotate'] = True
  102. if self.cli_args.augmented or self.cli_args.augment_noise:
  103. self.augmentation_dict['noise'] = 25.0
  104. self.use_cuda = torch.cuda.is_available()
  105. self.device = torch.device("cuda" if self.use_cuda else "cpu")
  106. self.model = self.initModel()
  107. self.optimizer = self.initOptimizer()
  108. def initModel(self):
  109. model = LunaModel()
  110. if self.use_cuda:
  111. log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
  112. if torch.cuda.device_count() > 1:
  113. model = nn.DataParallel(model)
  114. model = model.to(self.device)
  115. return model
  116. def initOptimizer(self):
  117. return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
  118. # return Adam(self.model.parameters())
  119. def initTrainDl(self):
  120. train_ds = LunaDataset(
  121. val_stride=10,
  122. isValSet_bool=False,
  123. ratio_int=int(self.cli_args.balanced),
  124. augmentation_dict=self.augmentation_dict,
  125. )
  126. batch_size = self.cli_args.batch_size
  127. if self.use_cuda:
  128. batch_size *= torch.cuda.device_count()
  129. train_dl = DataLoader(
  130. train_ds,
  131. batch_size=batch_size,
  132. num_workers=self.cli_args.num_workers,
  133. pin_memory=self.use_cuda,
  134. )
  135. return train_dl
  136. def initValDl(self):
  137. val_ds = LunaDataset(
  138. val_stride=10,
  139. isValSet_bool=True,
  140. )
  141. batch_size = self.cli_args.batch_size
  142. if self.use_cuda:
  143. batch_size *= torch.cuda.device_count()
  144. val_dl = DataLoader(
  145. val_ds,
  146. batch_size=batch_size,
  147. num_workers=self.cli_args.num_workers,
  148. pin_memory=self.use_cuda,
  149. )
  150. return val_dl
  151. def initTensorboardWriters(self):
  152. if self.trn_writer is None:
  153. log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
  154. self.trn_writer = SummaryWriter(
  155. log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
  156. self.val_writer = SummaryWriter(
  157. log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
  158. def main(self):
  159. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  160. train_dl = self.initTrainDl()
  161. val_dl = self.initValDl()
  162. for epoch_ndx in range(1, self.cli_args.epochs + 1):
  163. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  164. epoch_ndx,
  165. self.cli_args.epochs,
  166. len(train_dl),
  167. len(val_dl),
  168. self.cli_args.batch_size,
  169. (torch.cuda.device_count() if self.use_cuda else 1),
  170. ))
  171. trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
  172. self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
  173. valMetrics_t = self.doValidation(epoch_ndx, val_dl)
  174. self.logMetrics(epoch_ndx, 'val', valMetrics_t)
  175. if hasattr(self, 'trn_writer'):
  176. self.trn_writer.close()
  177. self.val_writer.close()
  178. def doTraining(self, epoch_ndx, train_dl):
  179. self.model.train()
  180. train_dl.dataset.shuffleSamples()
  181. trnMetrics_g = torch.zeros(
  182. METRICS_SIZE,
  183. len(train_dl.dataset),
  184. device=self.device,
  185. )
  186. batch_iter = enumerateWithEstimate(
  187. train_dl,
  188. "E{} Training".format(epoch_ndx),
  189. start_ndx=train_dl.num_workers,
  190. )
  191. for batch_ndx, batch_tup in batch_iter:
  192. self.optimizer.zero_grad()
  193. loss_var = self.computeBatchLoss(
  194. batch_ndx,
  195. batch_tup,
  196. train_dl.batch_size,
  197. trnMetrics_g,
  198. )
  199. loss_var.backward()
  200. self.optimizer.step()
  201. self.totalTrainingSamples_count += len(train_dl.dataset)
  202. return trnMetrics_g.to('cpu')
  203. def doValidation(self, epoch_ndx, val_dl):
  204. with torch.no_grad():
  205. self.model.eval()
  206. valMetrics_g = torch.zeros(
  207. METRICS_SIZE,
  208. len(val_dl.dataset),
  209. device=self.device,
  210. )
  211. batch_iter = enumerateWithEstimate(
  212. val_dl,
  213. "E{} Validation ".format(epoch_ndx),
  214. start_ndx=val_dl.num_workers,
  215. )
  216. for batch_ndx, batch_tup in batch_iter:
  217. self.computeBatchLoss(
  218. batch_ndx,
  219. batch_tup,
  220. val_dl.batch_size,
  221. valMetrics_g,
  222. )
  223. return valMetrics_g.to('cpu')
  224. def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
  225. input_t, label_t, _series_list, _center_list = batch_tup
  226. input_g = input_t.to(self.device, non_blocking=True)
  227. label_g = label_t.to(self.device, non_blocking=True)
  228. logits_g, probability_g = self.model(input_g)
  229. loss_func = nn.CrossEntropyLoss(reduction='none')
  230. loss_g = loss_func(
  231. logits_g,
  232. label_g[:,1],
  233. )
  234. start_ndx = batch_ndx * batch_size
  235. end_ndx = start_ndx + label_t.size(0)
  236. metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
  237. metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
  238. metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
  239. return loss_g.mean()
  240. def logMetrics(
  241. self,
  242. epoch_ndx,
  243. mode_str,
  244. metrics_t,
  245. classificationThreshold=0.5,
  246. ):
  247. self.initTensorboardWriters()
  248. log.info("E{} {}".format(
  249. epoch_ndx,
  250. type(self).__name__,
  251. ))
  252. negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
  253. negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold
  254. posLabel_mask = ~negLabel_mask
  255. posPred_mask = ~negPred_mask
  256. neg_count = int(negLabel_mask.sum())
  257. pos_count = int(posLabel_mask.sum())
  258. trueNeg_count = neg_correct = int((negLabel_mask & negPred_mask).sum())
  259. truePos_count = pos_correct = int((posLabel_mask & posPred_mask).sum())
  260. falsePos_count = neg_count - neg_correct
  261. falseNeg_count = pos_count - pos_correct
  262. metrics_dict = {}
  263. metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
  264. metrics_dict['loss/neg'] = metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
  265. metrics_dict['loss/pos'] = metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
  266. metrics_dict['correct/all'] = (pos_correct + neg_correct) / metrics_t.shape[1] * 100
  267. metrics_dict['correct/neg'] = (neg_correct) / neg_count * 100
  268. metrics_dict['correct/pos'] = (pos_correct) / pos_count * 100
  269. precision = metrics_dict['pr/precision'] = \
  270. truePos_count / np.float32(truePos_count + falsePos_count)
  271. recall = metrics_dict['pr/recall'] = \
  272. truePos_count / np.float32(truePos_count + falseNeg_count)
  273. metrics_dict['pr/f1_score'] = \
  274. 2 * (precision * recall) / (precision + recall)
  275. log.info(
  276. ("E{} {:8} {loss/all:.4f} loss, "
  277. + "{correct/all:-5.1f}% correct, "
  278. + "{pr/precision:.4f} precision, "
  279. + "{pr/recall:.4f} recall, "
  280. + "{pr/f1_score:.4f} f1 score"
  281. ).format(
  282. epoch_ndx,
  283. mode_str,
  284. **metrics_dict,
  285. )
  286. )
  287. log.info(
  288. ("E{} {:8} {loss/neg:.4f} loss, "
  289. + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
  290. ).format(
  291. epoch_ndx,
  292. mode_str + '_neg',
  293. neg_correct=neg_correct,
  294. neg_count=neg_count,
  295. **metrics_dict,
  296. )
  297. )
  298. log.info(
  299. ("E{} {:8} {loss/pos:.4f} loss, "
  300. + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
  301. ).format(
  302. epoch_ndx,
  303. mode_str + '_pos',
  304. pos_correct=pos_correct,
  305. pos_count=pos_count,
  306. **metrics_dict,
  307. )
  308. )
  309. writer = getattr(self, mode_str + '_writer')
  310. for key, value in metrics_dict.items():
  311. writer.add_scalar(key, value, self.totalTrainingSamples_count)
  312. writer.add_pr_curve(
  313. 'pr',
  314. metrics_t[METRICS_LABEL_NDX],
  315. metrics_t[METRICS_PRED_NDX],
  316. self.totalTrainingSamples_count,
  317. )
  318. bins = [x/50.0 for x in range(51)]
  319. negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
  320. posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)
  321. if negHist_mask.any():
  322. writer.add_histogram(
  323. 'is_neg',
  324. metrics_t[METRICS_PRED_NDX, negHist_mask],
  325. self.totalTrainingSamples_count,
  326. bins=bins,
  327. )
  328. if posHist_mask.any():
  329. writer.add_histogram(
  330. 'is_pos',
  331. metrics_t[METRICS_PRED_NDX, posHist_mask],
  332. self.totalTrainingSamples_count,
  333. bins=bins,
  334. )
  335. # score = 1 \
  336. # + metrics_dict['pr/f1_score'] \
  337. # - metrics_dict['loss/mal'] * 0.01 \
  338. # - metrics_dict['loss/all'] * 0.0001
  339. #
  340. # return score
  341. # def logModelMetrics(self, model):
  342. # writer = getattr(self, 'trn_writer')
  343. #
  344. # model = getattr(model, 'module', model)
  345. #
  346. # for name, param in model.named_parameters():
  347. # if param.requires_grad:
  348. # min_data = float(param.data.min())
  349. # max_data = float(param.data.max())
  350. # max_extent = max(abs(min_data), abs(max_data))
  351. #
  352. # # bins = [x/50*max_extent for x in range(-50, 51)]
  353. #
  354. # try:
  355. # writer.add_histogram(
  356. # name.rsplit('.', 1)[-1] + '/' + name,
  357. # param.data.cpu().numpy(),
  358. # # metrics_a[METRICS_PRED_NDX, negHist_mask],
  359. # self.totalTrainingSamples_count,
  360. # # bins=bins,
  361. # )
  362. # except Exception as e:
  363. # log.error([min_data, max_data])
  364. # raise
  365. if __name__ == '__main__':
  366. LunaTrainingApp().main()