| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- import argparse
- import datetime
- import os
- import sys
- import numpy as np
- from tensorboardX import SummaryWriter
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- from torch.optim import SGD
- from torch.utils.data import DataLoader
- from util.util import enumerateWithEstimate
- from .dsets import LunaDataset
- from util.logconf import logging
- from .model import LunaModel
# Module-level logger. INFO is the active verbosity; use logging.WARN for a
# quieter run or logging.DEBUG for more detail.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# Row indices into axis 0 of a metrics array (one row per quantity, one
# column per sample).
LABEL, PRED, LOSS = range(3)
- # ...
class LunaTrainingApp(object):
    """Command-line application that trains a ``LunaModel`` and logs metrics.

    Typical use is ``LunaTrainingApp(argv).main()``. ``main`` builds the
    training and test data loaders, wraps the model in ``nn.DataParallel`` on
    CUDA, runs the requested number of epochs, and writes per-epoch metrics
    to the module logger and to two tensorboardX ``SummaryWriter`` instances
    (``self.trn_writer`` and ``self.tst_writer``).
    """

    # NOTE: __init__ is a plain instance method. It was previously decorated
    # with @classmethod, which stored cli_args on the class itself and made
    # every instance share (and overwrite) the same parsed arguments.
    def __init__(self, sys_argv=None):
        """Parse command-line options into ``self.cli_args``.

        Args:
            sys_argv: Sequence of argument strings to parse. When None,
                ``sys.argv[1:]`` is used.
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=256,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=10,
            type=int,
        )
        parser.add_argument('--layers',
            help='Number of layers to the model',
            default=3,
            type=int,
        )
        parser.add_argument('--channels',
            help="Number of channels for the first layer's convolutions to the model (doubles each layer)",
            default=8,
            type=int,
        )
        parser.add_argument('--balanced',
            help="Balance the training data to half benign, half malignant.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--scaled',
            help="Scale the CT chunks to square voxels.",
            action='store_true',
            default=False,
        )
        parser.add_argument('--augmented',
            help="Augment the training data (implies --scaled).",
            action='store_true',
            default=False,
        )
        parser.add_argument('--tb-prefix',
            help="Data prefix to use for Tensorboard. Defaults to chapter.",
            default='p2ch4',
        )

        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Build loaders, model and optimizer, then train and test each epoch.

        Sets ``train_dl``, ``test_dl``, ``model``, ``optimizer``,
        ``trn_writer`` and ``tst_writer`` on the instance. Requires CUDA: the
        model and every batch are moved to the GPU, and the effective batch
        size is ``--batch-size`` times ``torch.cuda.device_count()``.

        Returns:
            None.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        device_count = torch.cuda.device_count()
        # --augmented implies --scaled (see the option's help text).
        scaled_bool = self.cli_args.scaled or self.cli_args.augmented
        loader_kwargs = {
            'batch_size': self.cli_args.batch_size * device_count,
            'num_workers': self.cli_args.num_workers,
            'pin_memory': True,
        }

        self.train_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=False,
                balanced_bool=self.cli_args.balanced,
                scaled_bool=scaled_bool,
                augmented_bool=self.cli_args.augmented,
            ),
            **loader_kwargs
        )

        # balanced_bool and augmented_bool are not passed for the test set,
        # so LunaDataset's own defaults apply there.
        self.test_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=True,
                scaled_bool=scaled_bool,
            ),
            **loader_kwargs
        )

        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()

        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)

        time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        log_dir = os.path.join('runs', self.cli_args.tb_prefix, time_str)
        self.trn_writer = SummaryWriter(log_dir=log_dir + '_train')
        self.tst_writer = SummaryWriter(log_dir=log_dir + '_test')

        # Close the writers even when an epoch raises, so that whatever was
        # already logged is flushed to disk.
        try:
            for epoch_ndx in range(1, self.cli_args.epochs + 1):
                log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                    epoch_ndx,
                    self.cli_args.epochs,
                    len(self.train_dl),
                    len(self.test_dl),
                    self.cli_args.batch_size,
                    device_count,
                ))

                trainingMetrics_ary = self._doTraining(epoch_ndx)
                testingMetrics_ary = self._doTesting(epoch_ndx)

                self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)
        finally:
            self.trn_writer.close()
            self.tst_writer.close()

    def _doTraining(self, epoch_ndx):
        """Run one optimisation pass over ``train_dl``; return its metrics array."""
        self.model.train()
        self.train_dl.dataset.shuffleSamples()

        batch_iter = enumerateWithEstimate(
            self.train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=self.train_dl.num_workers,
        )
        trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)), dtype=np.float32)

        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                self.train_dl.batch_size,
                trainingMetrics_ary,
            )
            loss_var.backward()
            self.optimizer.step()
            # Drop the reference so the autograd graph can be freed before
            # the next batch is processed.
            del loss_var

        return trainingMetrics_ary

    def _doTesting(self, epoch_ndx):
        """Run one evaluation pass over ``test_dl``; return its metrics array."""
        self.model.eval()
        self.test_dl.dataset.shuffleSamples()

        batch_iter = enumerateWithEstimate(
            self.test_dl,
            "E{} Testing ".format(epoch_ndx),
            start_ndx=self.test_dl.num_workers,
        )
        testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)), dtype=np.float32)

        # No backward pass happens here, so skip building autograd graphs.
        with torch.no_grad():
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx,
                    batch_tup,
                    self.test_dl.batch_size,
                    testingMetrics_ary,
                )

        return testingMetrics_ary

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_ary):
        """Run the model on one batch, record per-sample metrics, return the loss.

        Args:
            batch_ndx: Index of the batch; multiplied by ``batch_size`` to
                locate the first column of ``metrics_ary`` to fill.
            batch_tup: 4-tuple ``(input_tensor, label_tensor, series_list,
                center_list)``. Only the first two entries are used.
            batch_size: Nominal loader batch size (the last batch may hold
                fewer samples).
            metrics_ary: Float array with rows ``LABEL``, ``PRED``, ``LOSS``
                and one column per dataset sample; updated in place.

        Returns:
            The ``nn.MSELoss`` of the whole batch, suitable for ``backward()``.
        """
        input_tensor, label_tensor, _series_list, _center_list = batch_tup

        input_var = Variable(input_tensor.cuda())
        label_var = Variable(label_tensor.cuda())
        prediction_var = self.model(input_var)

        sample_count = label_tensor.size(0)
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + sample_count

        # The metrics need values only, never gradients.
        prediction_detached = prediction_var.detach()
        label_detached = label_var.detach()

        metrics_ary[LABEL, start_ndx:end_ndx] = label_tensor.numpy()[:, 0, 0]
        metrics_ary[PRED, start_ndx:end_ndx] = prediction_detached.cpu().numpy()[:, 0]

        loss_fn = nn.MSELoss()

        if sample_count:
            # Per-sample losses are stacked on the device and copied to the
            # host once. Indexing each scalar loss with ``.data[0]`` raises
            # IndexError on PyTorch versions with 0-dim tensors and would
            # force one device-to-host copy per sample.
            sampleLoss_tensor = torch.stack([
                loss_fn(prediction_detached[sample_ndx], label_detached[sample_ndx])
                for sample_ndx in range(sample_count)
            ])
            metrics_ary[LOSS, start_ndx:end_ndx] = sampleLoss_tensor.reshape(-1).cpu().numpy()

        # NOTE(review): prediction and label are indexed with different ranks
        # above ([:, 0] vs [:, 0, 0]); confirm their shapes match element for
        # element so MSELoss does not broadcast them against each other.
        return loss_fn(prediction_var, label_var)

    def logMetrics(self, epoch_ndx, trainingMetrics_ary, testingMetrics_ary):
        """Log and write to TensorBoard the metrics of one finished epoch.

        For both the training (``trn``) and testing (``tst``) arrays this
        computes precision, recall, F0.5/F1/F2 scores, mean losses and
        percent-correct figures using a 0.5 threshold on labels and
        predictions, logs three summary lines, and writes scalars, a PR curve
        and two prediction histograms to the matching ``SummaryWriter``.

        Ratios are NaN (with a NumPy RuntimeWarning) whenever a denominator
        is zero, e.g. when no sample is predicted positive.

        Args:
            epoch_ndx: Epoch number used in log lines and the step value.
            trainingMetrics_ary: Metrics array filled during training.
            testingMetrics_ary: Metrics array filled during testing.
        """
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        # Both writers use the training sample count for the step value, so
        # the two modes log at the same step for a given epoch.
        tb_ndx = epoch_ndx * trainingMetrics_ary.shape[1]

        for mode_str, metrics_ary in [('trn', trainingMetrics_ary), ('tst', testingMetrics_ary)]:
            pos_mask = metrics_ary[LABEL] > 0.5
            neg_mask = ~pos_mask

            pos_count = pos_mask.sum()
            neg_count = neg_mask.sum()

            truePos_count = (metrics_ary[PRED, pos_mask] > 0.5).sum()
            trueNeg_count = (metrics_ary[PRED, neg_mask] < 0.5).sum()
            falseNeg_count = pos_count - truePos_count
            falsePos_count = neg_count - trueNeg_count

            precision = truePos_count / (truePos_count + falsePos_count)
            recall = truePos_count / (truePos_count + falseNeg_count)

            metrics_dict = {}
            metrics_dict['pr/precision'] = precision
            metrics_dict['pr/recall'] = recall

            # F-beta score, see https://en.wikipedia.org/wiki/F1_score
            for beta in [0.5, 1, 2]:
                metrics_dict['pr/f{}_score'.format(beta)] = \
                    (1 + beta**2) * (precision * recall / (beta**2 * precision + recall))

            metrics_dict['loss/all'] = metrics_ary[LOSS].mean()
            metrics_dict['loss/ben'] = metrics_ary[LOSS, neg_mask].mean()
            metrics_dict['loss/mal'] = metrics_ary[LOSS, pos_mask].mean()

            metrics_dict['correct/all'] = (truePos_count + trueNeg_count) / metrics_ary.shape[1] * 100
            metrics_dict['correct/ben'] = (trueNeg_count) / neg_count * 100
            metrics_dict['correct/mal'] = (truePos_count) / pos_count * 100

            log.info(("E{} {:8} "
                     + "{loss/all:.4f} loss, "
                     + "{correct/all:-5.1f}% correct, "
                     + "{pr/precision:.4f} precision, "
                     + "{pr/recall:.4f} recall").format(
                epoch_ndx,
                mode_str,
                **metrics_dict
            ))
            log.info(("E{} {:8} "
                     + "{loss/ben:.4f} loss, "
                     + "{correct/ben:-5.1f}% correct").format(
                epoch_ndx,
                mode_str + '_ben',
                **metrics_dict
            ))
            log.info(("E{} {:8} "
                     + "{loss/mal:.4f} loss, "
                     + "{correct/mal:-5.1f}% correct").format(
                epoch_ndx,
                mode_str + '_mal',
                **metrics_dict
            ))

            # Resolves to self.trn_writer or self.tst_writer.
            writer = getattr(self, mode_str + '_writer')

            for key, value in metrics_dict.items():
                writer.add_scalar(key, value, tb_ndx)

            writer.add_pr_curve('pr', metrics_ary[LABEL], metrics_ary[PRED], tb_ndx)

            writer.add_histogram('is_mal', metrics_ary[PRED, pos_mask], tb_ndx)
            writer.add_histogram('is_ben', metrics_ary[PRED, neg_mask], tb_ndx)
if __name__ == '__main__':
    # main() returns None on success; map any falsy result to exit status 0.
    exit_status = LunaTrainingApp().main() or 0
    sys.exit(exit_status)
|