training.py

import argparse
import datetime
import hashlib
import os
import shutil
import sys

import numpy as np

from matplotlib import pyplot

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

import p2ch14.dsets
import p2ch14.model

from util.util import enumerateWithEstimate
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used by computeBatchLoss and logMetrics to index into metrics_t/metrics_a
METRICS_LABEL_NDX = 0
METRICS_PRED_NDX = 1
METRICS_PRED_P_NDX = 2
METRICS_LOSS_NDX = 3
METRICS_SIZE = 4
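
# Layout reminder for the per-sample metrics tensor filled in computeBatchLoss:
#   row METRICS_LABEL_NDX  -- ground-truth class index per sample
#   row METRICS_PRED_NDX   -- argmax-predicted class per sample
#   row METRICS_PRED_P_NDX -- predicted probability of the positive class
#   row METRICS_LOSS_NDX   -- unreduced per-sample cross-entropy loss
# giving an overall shape of (METRICS_SIZE, len(dataset)).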


class ClassificationTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=24,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--dataset',
            help="Which dataset to feed the model.",
            action='store',
            default='LunaDataset',
        )
        parser.add_argument('--model',
            help="Which model class name to use.",
            action='store',
            default='LunaModel',
        )
        parser.add_argument('--malignant',
            help="Train the model to classify nodules as benign or malignant.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--finetune',
            help="Start finetuning from this model.",
            default='',
        )
        parser.add_argument('--finetune-depth',
            help="Number of blocks (counted from the head) to include in finetuning",
            type=int,
            default=1,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch14',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='dlwpt',
        )

        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0

        self.augmentation_dict = {}
        if True:  # augmentation is hardwired on; the commented conditions show the per-flag variants
            # if self.cli_args.augmented or self.cli_args.augment_flip:
            self.augmentation_dict['flip'] = True
            # if self.cli_args.augmented or self.cli_args.augment_offset:
            self.augmentation_dict['offset'] = 0.1
            # if self.cli_args.augmented or self.cli_args.augment_scale:
            self.augmentation_dict['scale'] = 0.2
            # if self.cli_args.augmented or self.cli_args.augment_rotate:
            self.augmentation_dict['rotate'] = True
            # if self.cli_args.augmented or self.cli_args.augment_noise:
            self.augmentation_dict['noise'] = 25.0
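
        # Note: this dict is not referenced again in this file. computeBatchLoss
        # calls p2ch14.model.augment3d(input_g) without it, so it appears to be
        # kept from an earlier CLI-driven augmentation setup.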

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()

    def initModel(self):
        model_cls = getattr(p2ch14.model, self.cli_args.model)
        model = model_cls()
        if self.cli_args.finetune:
            d = torch.load(self.cli_args.finetune, map_location='cpu')
            model_blocks = [
                n for n, subm in model.named_children()
                if len(list(subm.parameters())) > 0
            ]
            finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
            log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
            # Load everything except the last (head) block, which stays randomly
            # initialized; strict=False tolerates the resulting missing keys.
            model.load_state_dict(
                {
                    k: v for k, v in d['model_state'].items()
                    if k.split('.')[0] != model_blocks[-1]
                },
                strict=False,
            )
            # Freeze every block that is not being finetuned.
            for n, p in model.named_parameters():
                if n.split('.')[0] not in finetune_blocks:
                    p.requires_grad_(False)
        if self.use_cuda:
            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model
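
    # Illustrative sanity check (not called anywhere in this file): after the
    # freeze in initModel, only parameters inside finetune_blocks should report
    # requires_grad=True.
    #
    #     for name, param in self.model.named_parameters():
    #         log.debug("%s requires_grad=%s", name, param.requires_grad)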

    def initOptimizer(self):
        lr = 0.003 if self.cli_args.finetune else 0.001
        return SGD(self.model.parameters(), lr=lr, weight_decay=1e-4)
        # return Adam(self.model.parameters(), lr=3e-4)

    def initTrainDl(self):
        ds_cls = getattr(p2ch14.dsets, self.cli_args.dataset)

        train_ds = ds_cls(
            val_stride=10,
            isValSet_bool=False,
            ratio_int=1,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl
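
    # Note: nn.DataParallel splits each batch across the visible GPUs, so the
    # batch_size scaling above keeps the per-GPU batch at --batch-size.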

    def initValDl(self):
        ds_cls = getattr(p2ch14.dsets, self.cli_args.dataset)

        val_ds = ds_cls(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix,
                                   self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        best_score = 0.0
        validation_cadence = 5 if not self.cli_args.finetune else 1
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            if epoch_ndx == 1 or epoch_ndx % validation_cadence == 0:
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                # TODO: this 'cls' will need to change for the malignant classifier
                self.saveModel('cls', epoch_ndx, score == best_score)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.val_writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        train_dl.dataset.shuffleSamples()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g,
                augment=True
            )

            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to('cpu')
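
    # totalTrainingSamples_count doubles as the global x-axis value for every
    # TensorBoard scalar, figure, and histogram written in logMetrics.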

    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx,
                    batch_tup,
                    val_dl.batch_size,
                    valMetrics_g,
                    augment=False
                )

        return valMetrics_g.to('cpu')

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         augment=True):
        input_t, label_t, index_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)
        index_g = index_t.to(self.device, non_blocking=True)

        if augment:
            input_g = p2ch14.model.augment3d(input_g)

        logits_g, probability_g = self.model(input_g)

        loss_g = nn.functional.cross_entropy(logits_g, label_g[:, 1],
                                             reduction="none")
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        _, predLabel_g = torch.max(probability_g, dim=1, keepdim=False,
                                   out=None)

        # log.debug(index_g)

        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = index_g
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = predLabel_g
        # metrics_g[METRICS_PRED_N_NDX, start_ndx:end_ndx] = probability_g[:,0]
        metrics_g[METRICS_PRED_P_NDX, start_ndx:end_ndx] = probability_g[:,1]
        # metrics_g[METRICS_PRED_M_NDX, start_ndx:end_ndx] = probability_g[:,2]
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g

        return loss_g.mean()
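
    # reduction="none" keeps one loss value per sample so that logMetrics can
    # break the loss down by class; only the scalar mean is backpropagated.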

    def logMetrics(
            self,
            epoch_ndx,
            mode_str,
            metrics_t,
            classificationThreshold=0.5,
    ):
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        if self.cli_args.dataset == 'MalignantLunaDataset':
            pos = 'mal'
            neg = 'ben'
        else:
            pos = 'pos'
            neg = 'neg'

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] == 0
        negPred_mask = metrics_t[METRICS_PRED_NDX] == 0

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        # benLabel_mask = metrics_t[METRICS_LABEL_NDX] == 1
        # benPred_mask = metrics_t[METRICS_PRED_NDX] == 1
        #
        # malLabel_mask = metrics_t[METRICS_LABEL_NDX] == 2
        # malPred_mask = metrics_t[METRICS_PRED_NDX] == 2
        # benLabel_mask = ~malLabel_mask & posLabel_mask
        # benPred_mask = ~malPred_mask & posLabel_mask

        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())
        # ben_count = int(benLabel_mask.sum())
        # mal_count = int(malLabel_mask.sum())

        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())
        # ben_correct = int((benLabel_mask & benPred_mask).sum())
        # mal_correct = int((malLabel_mask & malPred_mask).sum())

        trueNeg_count = neg_correct
        truePos_count = pos_correct

        falsePos_count = neg_count - neg_correct
        falseNeg_count = pos_count - pos_correct

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
        # metrics_dict['loss/ben'] = metrics_t[METRICS_LOSS_NDX, benLabel_mask].mean()
        # metrics_dict['loss/mal'] = metrics_t[METRICS_LOSS_NDX, malLabel_mask].mean()

        metrics_dict['correct/all'] = (pos_correct + neg_correct) / metrics_t.shape[1] * 100
        metrics_dict['correct/neg'] = (neg_correct) / neg_count * 100
        metrics_dict['correct/pos'] = (pos_correct) / pos_count * 100
        # metrics_dict['correct/ben'] = (ben_correct) / ben_count * 100
        # metrics_dict['correct/mal'] = (mal_correct) / mal_count * 100

        precision = metrics_dict['pr/precision'] = \
            truePos_count / np.float64(truePos_count + falsePos_count)
        recall = metrics_dict['pr/recall'] = \
            truePos_count / np.float64(truePos_count + falseNeg_count)

        metrics_dict['pr/f1_score'] = \
            2 * (precision * recall) / (precision + recall)
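
        # Build an ROC curve by sweeping the decision threshold from 1 down to
        # 0: at each threshold, TPR is the fraction of positives whose predicted
        # P(pos) clears it, and FPR the same fraction over the negatives. The
        # AUC is then the trapezoidal integral of TPR over FPR.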
        threshold = torch.linspace(1, 0, steps=100)  # steps is required in current PyTorch
        tpr = (metrics_t[None, METRICS_PRED_P_NDX, posLabel_mask] >= threshold[:, None]).sum(1).float() / pos_count
        fpr = (metrics_t[None, METRICS_PRED_P_NDX, negLabel_mask] >= threshold[:, None]).sum(1).float() / neg_count
        fp_diff = fpr[1:] - fpr[:-1]
        tp_avg = (tpr[1:] + tpr[:-1]) / 2
        auc = (fp_diff * tp_avg).sum()
        metrics_dict['auc'] = auc

        log.info(
            ("E{} {:8} {loss/all:.4f} loss, "
                 + "{correct/all:-5.1f}% correct, "
                 + "{pr/precision:.4f} precision, "
                 + "{pr/recall:.4f} recall, "
                 + "{pr/f1_score:.4f} f1 score, "
                 + "{auc:.4f} auc"
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/neg:.4f} loss, "
                 + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_' + neg,
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/pos:.4f} loss, "
                 + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_' + pos,
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )
        # log.info(
        #     ("E{} {:8} {loss/ben:.4f} loss, "
        #          + "{correct/ben:-5.1f}% correct ({ben_correct:} of {ben_count:})"
        #     ).format(
        #         epoch_ndx,
        #         mode_str + '_ben',
        #         ben_correct=ben_correct,
        #         ben_count=ben_count,
        #         **metrics_dict,
        #     )
        # )
        # log.info(
        #     ("E{} {:8} {loss/mal:.4f} loss, "
        #          + "{correct/mal:-5.1f}% correct ({mal_correct:} of {mal_count:})"
        #     ).format(
        #         epoch_ndx,
        #         mode_str + '_mal',
        #         mal_correct=mal_correct,
        #         mal_count=mal_count,
        #         **metrics_dict,
        #     )
        # )

        writer = getattr(self, mode_str + '_writer')
        for key, value in metrics_dict.items():
            key = key.replace('pos', pos)
            key = key.replace('neg', neg)
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        fig = pyplot.figure()
        pyplot.plot(fpr, tpr)
        writer.add_figure('roc', fig, self.totalTrainingSamples_count)

        writer.add_scalar('auc', auc, self.totalTrainingSamples_count)

        # # tag::logMetrics_writer_prcurve[]
        # writer.add_pr_curve(
        #     'pr',
        #     metrics_t[METRICS_LABEL_NDX],
        #     metrics_t[METRICS_PRED_P_NDX],
        #     self.totalTrainingSamples_count,
        # )
        # # end::logMetrics_writer_prcurve[]

        bins = np.linspace(0, 1)

        writer.add_histogram(
            'label_neg',
            metrics_t[METRICS_PRED_P_NDX, negLabel_mask],
            self.totalTrainingSamples_count,
            bins=bins
        )
        writer.add_histogram(
            'label_pos',
            metrics_t[METRICS_PRED_P_NDX, posLabel_mask],
            self.totalTrainingSamples_count,
            bins=bins
        )

        if not self.cli_args.malignant:
            score = metrics_dict['pr/f1_score']
        else:
            score = metrics_dict['auc']

        return score
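
    # F1 selects the best nodule-vs-non-nodule checkpoint; AUC selects the best
    # malignancy checkpoint, since AUC does not depend on the 0.5 threshold.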

    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )

        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        # Unwrap nn.DataParallel so the saved keys carry no 'module.' prefix.
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)

        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            best_path = os.path.join(
                'data-unversioned',
                'part2',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(
                    type_str,
                    self.time_str,
                    self.cli_args.comment,
                    'best',
                )
            )
            shutil.copyfile(file_path, best_path)

            log.debug("Saved model params to {}".format(best_path))

        with open(file_path, 'rb') as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())
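
    # The SHA1 of the saved file is logged so a checkpoint on disk can later be
    # matched to the exact training run that produced it.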

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             try:
    #                 writer.add_histogram(
    #                     name.rsplit('.', 1)[-1] + '/' + name,
    #                     param.data.cpu().numpy(),
    #                     # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                     self.totalTrainingSamples_count,
    #                     # bins=bins,
    #                 )
    #             except Exception as e:
    #                 log.error([min_data, max_data])
    #                 raise


if __name__ == '__main__':
    ClassificationTrainingApp().main()
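
# Example invocations (illustrative; the finetune checkpoint path below is a
# placeholder for a .state file produced by an earlier run):
#
#   python -m p2ch14.training --epochs 5 nodule-nonnodule
#   python -m p2ch14.training --dataset MalignantLunaDataset --malignant \
#       --finetune data-unversioned/part2/models/p2ch14/<...>.state finetune-head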