import argparse
import datetime
import hashlib
import os
import shutil
import sys

import numpy as np
from matplotlib import pyplot

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

import p2ch14.dsets
import p2ch14.model

from util.util import enumerateWithEstimate
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
METRICS_LABEL_NDX = 0
METRICS_PRED_NDX = 1
METRICS_PRED_P_NDX = 2
METRICS_LOSS_NDX = 3
METRICS_SIZE = 4
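
# Layout of the metrics tensors indexed by the constants above: one column per
# sample, with row 0 holding the integer class label, row 1 the predicted
# class, row 2 the predicted probability of the positive class, and row 3 the
# per-sample loss.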

class ClassificationTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=24,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,
            type=int,
        )
        parser.add_argument('--dataset',
            help="What dataset to feed the model.",
            action='store',
            default='LunaDataset',
        )
        parser.add_argument('--model',
            help="What model class name to use.",
            action='store',
            default='LunaModel',
        )
        parser.add_argument('--malignant',
            help="Train the model to classify nodules as benign or malignant.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--finetune',
            help="Start finetuning from this model.",
            default='',
        )
        parser.add_argument('--finetune-depth',
            help="Number of blocks (counted from the head) to include in finetuning",
            type=int,
            default=1,
        )
        parser.add_argument('--tb-prefix',
            default='p2ch14',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='dlwpt',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0

        self.augmentation_dict = {}
        if True:
            # if self.cli_args.augmented or self.cli_args.augment_flip:
            self.augmentation_dict['flip'] = True
            # if self.cli_args.augmented or self.cli_args.augment_offset:
            self.augmentation_dict['offset'] = 0.1
            # if self.cli_args.augmented or self.cli_args.augment_scale:
            self.augmentation_dict['scale'] = 0.2
            # if self.cli_args.augmented or self.cli_args.augment_rotate:
            self.augmentation_dict['rotate'] = True
            # if self.cli_args.augmented or self.cli_args.augment_noise:
            self.augmentation_dict['noise'] = 25.0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()
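
    # initModel instantiates whichever class --model names from p2ch14.model.
    # When --finetune points at a saved checkpoint, the state dict is loaded
    # without the final block's weights (so the head restarts from scratch),
    # and every parameter outside the last --finetune-depth blocks is frozen
    # with requires_grad_(False), leaving only those blocks trainable.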
    def initModel(self):
        model_cls = getattr(p2ch14.model, self.cli_args.model)
        model = model_cls()

        if self.cli_args.finetune:
            d = torch.load(self.cli_args.finetune, map_location='cpu')
            model_blocks = [
                n for n, subm in model.named_children()
                if len(list(subm.parameters())) > 0
            ]
            finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
            log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
            model.load_state_dict(
                {
                    k: v for k, v in d['model_state'].items()
                    if k.split('.')[0] not in model_blocks[-1]
                },
                strict=False,
            )
            for n, p in model.named_parameters():
                if n.split('.')[0] not in finetune_blocks:
                    p.requires_grad_(False)

        if self.use_cuda:
            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model
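
    # Plain SGD with a small weight decay; the learning rate is raised from
    # 0.001 to 0.003 when finetuning, presumably because only the unfrozen
    # head blocks receive gradient updates.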
    def initOptimizer(self):
        lr = 0.003 if self.cli_args.finetune else 0.001
        return SGD(self.model.parameters(), lr=lr, weight_decay=1e-4)
        # return Adam(self.model.parameters(), lr=3e-4)
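
    # Both loaders multiply the per-GPU batch size by the device count because
    # nn.DataParallel splits each batch across the available GPUs, and
    # pin_memory speeds up the host-to-device copies done with
    # non_blocking=True in computeBatchLoss. ratio_int=1 asks the training
    # dataset for balanced sampling of positive and negative candidates, while
    # the validation set keeps its natural class distribution.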
    def initTrainDl(self):
        ds_cls = getattr(p2ch14.dsets, self.cli_args.dataset)

        train_ds = ds_cls(
            val_stride=10,
            isValSet_bool=False,
            ratio_int=1,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl
    def initValDl(self):
        ds_cls = getattr(p2ch14.dsets, self.cli_args.dataset)

        val_ds = ds_cls(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix,
                                   self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
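
    # main drives the epoch loop: training runs every epoch, but validation
    # (and checkpointing) only happens on epoch 1 and then every
    # validation_cadence epochs, which is every epoch when finetuning. The
    # score returned by logMetrics decides whether the checkpoint written by
    # saveModel is also copied to the 'best' file.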
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        best_score = 0.0
        validation_cadence = 5 if not self.cli_args.finetune else 1
        for epoch_ndx in range(1, self.cli_args.epochs + 1):

            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            if epoch_ndx == 1 or epoch_ndx % validation_cadence == 0:
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                # TODO: this 'cls' will need to change for the malignant classifier
                self.saveModel('cls', epoch_ndx, score == best_score)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.val_writer.close()
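
    # doTraining runs one training epoch: per-sample metrics are accumulated in
    # a GPU-resident tensor sized to the whole dataset, and the usual
    # zero_grad / backward / step loop updates the model. The running
    # totalTrainingSamples_count is the x-axis used for the TensorBoard series.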
    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        train_dl.dataset.shuffleSamples()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g,
                augment=True,
            )

            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to('cpu')
    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx,
                    batch_tup,
                    val_dl.batch_size,
                    valMetrics_g,
                    augment=False,
                )

        return valMetrics_g.to('cpu')
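
    # computeBatchLoss moves a batch to the GPU, applies the 3D augmentations
    # there only when augment=True (training), and computes an unreduced
    # cross-entropy loss so the label, prediction, positive-class probability,
    # and loss of every sample can be recorded in metrics_g. The mean loss is
    # returned for the backward pass.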
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         augment=True):
        input_t, label_t, index_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)
        index_g = index_t.to(self.device, non_blocking=True)

        if augment:
            input_g = p2ch14.model.augment3d(input_g)

        logits_g, probability_g = self.model(input_g)

        loss_g = nn.functional.cross_entropy(logits_g, label_g[:, 1],
                                             reduction="none")
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        _, predLabel_g = torch.max(probability_g, dim=1, keepdim=False,
                                   out=None)

        # log.debug(index_g)
        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = index_g
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = predLabel_g
        # metrics_g[METRICS_PRED_N_NDX, start_ndx:end_ndx] = probability_g[:,0]
        metrics_g[METRICS_PRED_P_NDX, start_ndx:end_ndx] = probability_g[:,1]
        # metrics_g[METRICS_PRED_M_NDX, start_ndx:end_ndx] = probability_g[:,2]
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
        return loss_g.mean()
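
    # logMetrics summarizes an epoch: per-class loss and accuracy, precision,
    # recall, and F1 from the confusion-matrix counts, an ROC curve traced by
    # sweeping the classification threshold, and an AUC computed with the
    # trapezoidal rule. Everything is written to TensorBoard; the returned
    # score (F1 for the nodule classifier, AUC when --malignant is set) is
    # what main() uses to track the best checkpoint.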
    def logMetrics(
            self,
            epoch_ndx,
            mode_str,
            metrics_t,
            classificationThreshold=0.5,
    ):
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        if self.cli_args.dataset == 'MalignantLunaDataset':
            pos = 'mal'
            neg = 'ben'
        else:
            pos = 'pos'
            neg = 'neg'

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] == 0
        negPred_mask = metrics_t[METRICS_PRED_NDX] == 0

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        # benLabel_mask = metrics_t[METRICS_LABEL_NDX] == 1
        # benPred_mask = metrics_t[METRICS_PRED_NDX] == 1
        #
        # malLabel_mask = metrics_t[METRICS_LABEL_NDX] == 2
        # malPred_mask = metrics_t[METRICS_PRED_NDX] == 2
        # benLabel_mask = ~malLabel_mask & posLabel_mask
        # benPred_mask = ~malPred_mask & posLabel_mask

        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())
        # ben_count = int(benLabel_mask.sum())
        # mal_count = int(malLabel_mask.sum())

        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())
        # ben_correct = int((benLabel_mask & benPred_mask).sum())
        # mal_correct = int((malLabel_mask & malPred_mask).sum())

        trueNeg_count = neg_correct
        truePos_count = pos_correct

        falsePos_count = neg_count - neg_correct
        falseNeg_count = pos_count - pos_correct

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
        # metrics_dict['loss/ben'] = metrics_t[METRICS_LOSS_NDX, benLabel_mask].mean()
        # metrics_dict['loss/mal'] = metrics_t[METRICS_LOSS_NDX, malLabel_mask].mean()

        metrics_dict['correct/all'] = (pos_correct + neg_correct) / metrics_t.shape[1] * 100
        metrics_dict['correct/neg'] = (neg_correct) / neg_count * 100
        metrics_dict['correct/pos'] = (pos_correct) / pos_count * 100
        # metrics_dict['correct/ben'] = (ben_correct) / ben_count * 100
        # metrics_dict['correct/mal'] = (mal_correct) / mal_count * 100

        precision = metrics_dict['pr/precision'] = \
            truePos_count / np.float64(truePos_count + falsePos_count)
        recall = metrics_dict['pr/recall'] = \
            truePos_count / np.float64(truePos_count + falseNeg_count)

        metrics_dict['pr/f1_score'] = \
            2 * (precision * recall) / (precision + recall)
        # Sweep the classification threshold from 1 to 0 to trace out the ROC
        # curve; steps=100 is passed explicitly here since it was only the
        # implicit default in older PyTorch versions.
        threshold = torch.linspace(1, 0, steps=100)
        tpr = (metrics_t[None, METRICS_PRED_P_NDX, posLabel_mask] >= threshold[:, None]).sum(1).float() / pos_count
        fpr = (metrics_t[None, METRICS_PRED_P_NDX, negLabel_mask] >= threshold[:, None]).sum(1).float() / neg_count
        fp_diff = fpr[1:] - fpr[:-1]
        tp_avg = (tpr[1:] + tpr[:-1]) / 2
        auc = (fp_diff * tp_avg).sum()
        metrics_dict['auc'] = auc
        log.info(
            ("E{} {:8} {loss/all:.4f} loss, "
                 + "{correct/all:-5.1f}% correct, "
                 + "{pr/precision:.4f} precision, "
                 + "{pr/recall:.4f} recall, "
                 + "{pr/f1_score:.4f} f1 score, "
                 + "{auc:.4f} auc"
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/neg:.4f} loss, "
                 + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_' + neg,
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/pos:.4f} loss, "
                 + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_' + pos,
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )
        # log.info(
        #     ("E{} {:8} {loss/ben:.4f} loss, "
        #          + "{correct/ben:-5.1f}% correct ({ben_correct:} of {ben_count:})"
        #     ).format(
        #         epoch_ndx,
        #         mode_str + '_ben',
        #         ben_correct=ben_correct,
        #         ben_count=ben_count,
        #         **metrics_dict,
        #     )
        # )
        # log.info(
        #     ("E{} {:8} {loss/mal:.4f} loss, "
        #          + "{correct/mal:-5.1f}% correct ({mal_correct:} of {mal_count:})"
        #     ).format(
        #         epoch_ndx,
        #         mode_str + '_mal',
        #         mal_correct=mal_correct,
        #         mal_count=mal_count,
        #         **metrics_dict,
        #     )
        # )

        writer = getattr(self, mode_str + '_writer')
        for key, value in metrics_dict.items():
            key = key.replace('pos', pos)
            key = key.replace('neg', neg)
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        fig = pyplot.figure()
        pyplot.plot(fpr, tpr)
        writer.add_figure('roc', fig, self.totalTrainingSamples_count)

        writer.add_scalar('auc', auc, self.totalTrainingSamples_count)

        # # tag::logMetrics_writer_prcurve[]
        # writer.add_pr_curve(
        #     'pr',
        #     metrics_t[METRICS_LABEL_NDX],
        #     metrics_t[METRICS_PRED_P_NDX],
        #     self.totalTrainingSamples_count,
        # )
        # # end::logMetrics_writer_prcurve[]

        bins = np.linspace(0, 1)

        writer.add_histogram(
            'label_neg',
            metrics_t[METRICS_PRED_P_NDX, negLabel_mask],
            self.totalTrainingSamples_count,
            bins=bins,
        )
        writer.add_histogram(
            'label_pos',
            metrics_t[METRICS_PRED_P_NDX, posLabel_mask],
            self.totalTrainingSamples_count,
            bins=bins,
        )

        if not self.cli_args.malignant:
            score = metrics_dict['pr/f1_score']
        else:
            score = metrics_dict['auc']

        return score
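
    # saveModel serializes the model and optimizer state dicts (not the module
    # objects themselves), unwrapping nn.DataParallel first so the key names
    # stay consistent. The best-scoring checkpoint is additionally copied to a
    # '.best.state' file, and the SHA1 of the saved file is logged so a result
    # can be matched to an exact set of weights.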
    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )

        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        state = {
            'model_state': model.state_dict(),
            'model_name': type(model).__name__,
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_name': type(self.optimizer).__name__,
            'epoch': epoch_ndx,
            'totalTrainingSamples_count': self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)

        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            best_path = os.path.join(
                'data-unversioned',
                'part2',
                'models',
                self.cli_args.tb_prefix,
                '{}_{}_{}.{}.state'.format(
                    type_str,
                    self.time_str,
                    self.cli_args.comment,
                    'best',
                )
            )
            shutil.copyfile(file_path, best_path)

            log.debug("Saved model params to {}".format(best_path))

        with open(file_path, 'rb') as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())
    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             try:
    #                 writer.add_histogram(
    #                     name.rsplit('.', 1)[-1] + '/' + name,
    #                     param.data.cpu().numpy(),
    #                     # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                     self.totalTrainingSamples_count,
    #                     # bins=bins,
    #                 )
    #             except Exception as e:
    #                 log.error([min_data, max_data])
    #                 raise
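
# Example invocation, assuming this module lives at p2ch14/training.py as the
# imports above suggest (both the module path and the --finetune checkpoint
# path below are placeholders, not values taken from this file):
#
#   python -m p2ch14.training --epochs 5 --dataset MalignantLunaDataset \
#       --malignant --finetune data-unversioned/part2/models/p2ch14/cls_best.state \
#       --finetune-depth 2 finetune-malignant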
if __name__ == '__main__':
    ClassificationTrainingApp().main()