# train_cls.py
  1. import argparse
  2. import datetime
  3. import os
  4. import sys
  5. import numpy as np
  6. from tensorboardX import SummaryWriter
  7. import torch
  8. import torch.nn as nn
  9. from torch.optim import SGD, Adam
  10. from torch.utils.data import DataLoader
  11. from util.util import enumerateWithEstimate
  12. from .dsets import LunaDataset
  13. from .model_cls import LunaModel
  14. from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary.
# The metrics tensor is METRICS_SIZE rows by num_samples columns.
METRICS_LABEL_NDX=0  # ground-truth label for each sample
METRICS_PRED_NDX=1   # predicted probability of the positive (malignant) class
METRICS_LOSS_NDX=2   # per-sample cross-entropy loss
METRICS_SIZE = 3     # number of metric rows above
  24. class LunaTrainingApp(object):
  25. def __init__(self, sys_argv=None):
  26. if sys_argv is None:
  27. sys_argv = sys.argv[1:]
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument('--batch-size',
  30. help='Batch size to use for training',
  31. default=32,
  32. type=int,
  33. )
  34. parser.add_argument('--num-workers',
  35. help='Number of worker processes for background data loading',
  36. default=8,
  37. type=int,
  38. )
  39. parser.add_argument('--epochs',
  40. help='Number of epochs to train for',
  41. default=1,
  42. type=int,
  43. )
  44. parser.add_argument('--balanced',
  45. help="Balance the training data to half benign, half malignant.",
  46. action='store_true',
  47. default=False,
  48. )
  49. parser.add_argument('--augmented',
  50. help="Augment the training data.",
  51. action='store_true',
  52. default=False,
  53. )
  54. parser.add_argument('--augment-flip',
  55. help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
  56. action='store_true',
  57. default=False,
  58. )
  59. parser.add_argument('--augment-offset',
  60. help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
  61. action='store_true',
  62. default=False,
  63. )
  64. parser.add_argument('--augment-scale',
  65. help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
  66. action='store_true',
  67. default=False,
  68. )
  69. parser.add_argument('--augment-rotate',
  70. help="Augment the training data by randomly rotating the data around the head-foot axis.",
  71. action='store_true',
  72. default=False,
  73. )
  74. parser.add_argument('--augment-noise',
  75. help="Augment the training data by randomly adding noise to the data.",
  76. action='store_true',
  77. default=False,
  78. )
  79. parser.add_argument('--tb-prefix',
  80. default='p2ch12',
  81. help="Data prefix to use for Tensorboard run. Defaults to chapter.",
  82. )
  83. parser.add_argument('comment',
  84. help="Comment suffix for Tensorboard run.",
  85. nargs='?',
  86. default='none',
  87. )
  88. self.cli_args = parser.parse_args(sys_argv)
  89. self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
  90. self.totalTrainingSamples_count = 0
  91. self.trn_writer = None
  92. self.tst_writer = None
  93. self.augmentation_dict = {}
  94. if self.cli_args.augmented or self.cli_args.augment_flip:
  95. self.augmentation_dict['flip'] = True
  96. if self.cli_args.augmented or self.cli_args.augment_offset:
  97. self.augmentation_dict['offset'] = 0.1
  98. if self.cli_args.augmented or self.cli_args.augment_scale:
  99. self.augmentation_dict['scale'] = 0.2
  100. if self.cli_args.augmented or self.cli_args.augment_rotate:
  101. self.augmentation_dict['rotate'] = True
  102. if self.cli_args.augmented or self.cli_args.augment_noise:
  103. self.augmentation_dict['noise'] = 25.0
  104. self.use_cuda = torch.cuda.is_available()
  105. self.device = torch.device("cuda" if self.use_cuda else "cpu")
  106. self.model = self.initModel()
  107. self.optimizer = self.initOptimizer()
  108. def initModel(self):
  109. model = LunaModel()
  110. if self.use_cuda:
  111. if torch.cuda.device_count() > 1:
  112. model = nn.DataParallel(model)
  113. model = model.to(self.device)
  114. return model
  115. def initOptimizer(self):
  116. return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
  117. # return Adam(self.model.parameters())
  118. def initTrainDl(self):
  119. train_ds = LunaDataset(
  120. test_stride=10,
  121. isTestSet_bool=False,
  122. ratio_int=int(self.cli_args.balanced),
  123. augmentation_dict=self.augmentation_dict,
  124. )
  125. train_dl = DataLoader(
  126. train_ds,
  127. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  128. num_workers=self.cli_args.num_workers,
  129. pin_memory=self.use_cuda,
  130. )
  131. return train_dl
  132. def initTestDl(self):
  133. test_ds = LunaDataset(
  134. test_stride=10,
  135. isTestSet_bool=True,
  136. )
  137. test_dl = DataLoader(
  138. test_ds,
  139. batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
  140. num_workers=self.cli_args.num_workers,
  141. pin_memory=self.use_cuda,
  142. )
  143. return test_dl
  144. def initTensorboardWriters(self):
  145. if self.trn_writer is None:
  146. log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
  147. self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
  148. self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
  149. # eng::tb_writer[]
  150. def main(self):
  151. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  152. train_dl = self.initTrainDl()
  153. test_dl = self.initTestDl()
  154. best_score = 0.0
  155. for epoch_ndx in range(1, self.cli_args.epochs + 1):
  156. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  157. epoch_ndx,
  158. self.cli_args.epochs,
  159. len(train_dl),
  160. len(test_dl),
  161. self.cli_args.batch_size,
  162. (torch.cuda.device_count() if self.use_cuda else 1),
  163. ))
  164. trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
  165. self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
  166. tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
  167. score = self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
  168. best_score = max(score, best_score)
  169. self.saveModel('cls', epoch_ndx, score == best_score)
  170. if hasattr(self, 'trn_writer'):
  171. self.trn_writer.close()
  172. self.tst_writer.close()
    def doTraining(self, epoch_ndx, train_dl):
        """Train for one epoch; return the per-sample metrics tensor on the CPU.

        The returned tensor is METRICS_SIZE x num_samples, filled in-place by
        computeBatchLoss (label / prediction / loss rows).
        """
        self.model.train()
        # Re-order the dataset's samples each epoch (presumably re-randomizes
        # the balanced sampling — defined in dsets; confirm there).
        train_dl.dataset.shuffleSamples()
        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            # Standard step: clear grads, forward + loss, backward, update.
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trainingMetrics_devtensor
            )
            loss_var.backward()
            self.optimizer.step()
            # Drop the loss (and its autograd graph) before the next batch.
            del loss_var
        # Running sample count; used as the TensorBoard x-axis in logMetrics.
        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
        return trainingMetrics_devtensor.to('cpu')
  195. def doTesting(self, epoch_ndx, test_dl):
  196. with torch.no_grad():
  197. self.model.eval()
  198. testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
  199. batch_iter = enumerateWithEstimate(
  200. test_dl,
  201. "E{} Testing ".format(epoch_ndx),
  202. start_ndx=test_dl.num_workers,
  203. )
  204. for batch_ndx, batch_tup in batch_iter:
  205. self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
  206. return testingMetrics_devtensor.to('cpu')
  207. def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
  208. input_tensor, label_tensor, _series_list, _center_list = batch_tup
  209. input_devtensor = input_tensor.to(self.device, non_blocking=True)
  210. label_devtensor = label_tensor.to(self.device, non_blocking=True)
  211. logits_devtensor, probability_devtensor = self.model(input_devtensor)
  212. # log.debug(['input', input_devtensor.min().item(), input_devtensor.max().item()])
  213. # log.debug(['label', label_devtensor.min().item(), label_devtensor.max().item()])
  214. # log.debug(['logits', logits_devtensor.min().item(), logits_devtensor.max().item()])
  215. # log.debug(['probability', probability_devtensor.min().item(), probability_devtensor.max().item()])
  216. loss_func = nn.CrossEntropyLoss(reduction='none')
  217. loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
  218. # log.debug(['loss', loss_devtensor.min().item(), loss_devtensor.max().item()])
  219. start_ndx = batch_ndx * batch_size
  220. end_ndx = start_ndx + label_tensor.size(0)
  221. metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
  222. metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
  223. metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
  224. return loss_devtensor.mean()
  225. def logMetrics(
  226. self,
  227. epoch_ndx,
  228. mode_str,
  229. metrics_tensor,
  230. ):
  231. self.initTensorboardWriters()
  232. log.info("E{} {}".format(
  233. epoch_ndx,
  234. type(self).__name__,
  235. ))
  236. metrics_ary = metrics_tensor.cpu().detach().numpy()
  237. # assert np.isfinite(metrics_ary).all()
  238. benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
  239. benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
  240. malLabel_mask = ~benLabel_mask
  241. malPred_mask = ~benPred_mask
  242. benLabel_count = benLabel_mask.sum()
  243. malLabel_count = malLabel_mask.sum()
  244. trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
  245. truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
  246. falsePos_count = benLabel_count - benCorrect_count
  247. falseNeg_count = malLabel_count - malCorrect_count
  248. metrics_dict = {}
  249. metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
  250. metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
  251. metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
  252. metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
  253. metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
  254. metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
  255. precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
  256. recall = metrics_dict['pr/recall'] = truePos_count / (truePos_count + falseNeg_count)
  257. metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
  258. log.info(
  259. ("E{} {:8} "
  260. + "{loss/all:.4f} loss, "
  261. + "{correct/all:-5.1f}% correct, "
  262. + "{pr/precision:.4f} precision, "
  263. + "{pr/recall:.4f} recall, "
  264. + "{pr/f1_score:.4f} f1 score"
  265. ).format(
  266. epoch_ndx,
  267. mode_str,
  268. **metrics_dict,
  269. )
  270. )
  271. log.info(
  272. ("E{} {:8} "
  273. + "{loss/ben:.4f} loss, "
  274. + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
  275. ).format(
  276. epoch_ndx,
  277. mode_str + '_ben',
  278. benCorrect_count=benCorrect_count,
  279. benLabel_count=benLabel_count,
  280. **metrics_dict,
  281. )
  282. )
  283. log.info(
  284. ("E{} {:8} "
  285. + "{loss/mal:.4f} loss, "
  286. + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
  287. ).format(
  288. epoch_ndx,
  289. mode_str + '_mal',
  290. malCorrect_count=malCorrect_count,
  291. malLabel_count=malLabel_count,
  292. **metrics_dict,
  293. )
  294. )
  295. writer = getattr(self, mode_str + '_writer')
  296. for key, value in metrics_dict.items():
  297. writer.add_scalar(key, value, self.totalTrainingSamples_count)
  298. writer.add_pr_curve(
  299. 'pr',
  300. metrics_ary[METRICS_LABEL_NDX],
  301. metrics_ary[METRICS_PRED_NDX],
  302. self.totalTrainingSamples_count,
  303. )
  304. bins = [x/50.0 for x in range(51)]
  305. benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
  306. malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
  307. if benHist_mask.any():
  308. writer.add_histogram(
  309. 'is_ben',
  310. metrics_ary[METRICS_PRED_NDX, benHist_mask],
  311. self.totalTrainingSamples_count,
  312. bins=bins,
  313. )
  314. if malHist_mask.any():
  315. writer.add_histogram(
  316. 'is_mal',
  317. metrics_ary[METRICS_PRED_NDX, malHist_mask],
  318. self.totalTrainingSamples_count,
  319. bins=bins,
  320. )
  321. score = 1 \
  322. + metrics_dict['pr/f1_score'] \
  323. - metrics_dict['loss/mal'] * 0.01 \
  324. - metrics_dict['loss/all'] * 0.0001
  325. return score
  326. def saveModel(self, type_str, epoch_ndx, isBest=False):
  327. file_path = os.path.join(
  328. 'data-unversioned',
  329. 'part2',
  330. 'models',
  331. self.cli_args.tb_prefix,
  332. '{}_{}_{}.{}.state'.format(
  333. type_str,
  334. self.time_str,
  335. self.cli_args.comment,
  336. self.totalTrainingSamples_count,
  337. )
  338. )
  339. os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
  340. model = self.model
  341. if hasattr(model, 'module'):
  342. model = model.module
  343. state = {
  344. 'model_state': model.state_dict(),
  345. 'model_name': type(model).__name__,
  346. 'optimizer_state' : self.optimizer.state_dict(),
  347. 'optimizer_name': type(self.optimizer).__name__,
  348. 'epoch': epoch_ndx,
  349. 'totalTrainingSamples_count': self.totalTrainingSamples_count,
  350. # 'resumed_from': self.cli_args.resume,
  351. }
  352. torch.save(state, file_path)
  353. log.debug("Saved model params to {}".format(file_path))
  354. if isBest:
  355. file_path = os.path.join(
  356. 'data-unversioned',
  357. 'part2',
  358. 'models',
  359. self.cli_args.tb_prefix,
  360. '{}_{}_{}.{}.state'.format(
  361. type_str,
  362. self.time_str,
  363. self.cli_args.comment,
  364. 'best',
  365. )
  366. )
  367. torch.save(state, file_path)
  368. log.debug("Saved model params to {}".format(file_path))
if __name__ == '__main__':
    # Script entry point; process exit code is main()'s return value or 0.
    sys.exit(LunaTrainingApp().main() or 0)