| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- import argparse
- import datetime
- import os
- import sys
- import numpy as np
- from tensorboardX import SummaryWriter
- import torch
- import torch.nn as nn
- from torch.optim import SGD
- from torch.utils.data import DataLoader
- from util.util import enumerateWithEstimate
- from .dsets import LunaDataset
- from util.logconf import logging
- from .model import LunaModel
- log = logging.getLogger(__name__)
- # log.setLevel(logging.WARN)
- log.setLevel(logging.INFO)
- # log.setLevel(logging.DEBUG)
- # Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
- METRICS_LABEL_NDX=0
- METRICS_PRED_NDX=1
- METRICS_LOSS_NDX=2
- class LunaTrainingApp(object):
- def __init__(self, sys_argv=None):
- if sys_argv is None:
- sys_argv = sys.argv[1:]
- parser = argparse.ArgumentParser()
- parser.add_argument('--batch-size',
- help='Batch size to use for training',
- default=32,
- type=int,
- )
- parser.add_argument('--num-workers',
- help='Number of worker processes for background data loading',
- default=8,
- type=int,
- )
- parser.add_argument('--epochs',
- help='Number of epochs to train for',
- default=1,
- type=int,
- )
- self.cli_args = parser.parse_args(sys_argv)
- self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
- def main(self):
- log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
- self.use_cuda = torch.cuda.is_available()
- self.device = torch.device("cuda" if self.use_cuda else "cpu")
- self.model = LunaModel()
- if self.use_cuda:
- if torch.cuda.device_count() > 1:
- self.model = nn.DataParallel(self.model)
- self.model = self.model.to(self.device)
- self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
- train_dl = DataLoader(
- LunaDataset(
- test_stride=10,
- isTestSet_bool=False,
- ),
- batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
- num_workers=self.cli_args.num_workers,
- pin_memory=self.use_cuda,
- )
- test_dl = DataLoader(
- LunaDataset(
- test_stride=10,
- isTestSet_bool=True,
- ),
- batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
- num_workers=self.cli_args.num_workers,
- pin_memory=self.use_cuda,
- )
- for epoch_ndx in range(1, self.cli_args.epochs + 1):
- log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
- epoch_ndx,
- self.cli_args.epochs,
- len(train_dl),
- len(test_dl),
- self.cli_args.batch_size,
- (torch.cuda.device_count() if self.use_cuda else 1),
- ))
- # Training loop, very similar to below
- self.model.train()
- trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
- batch_iter = enumerateWithEstimate(
- train_dl,
- "E{} Training".format(epoch_ndx),
- start_ndx=train_dl.num_workers,
- )
- for batch_ndx, batch_tup in batch_iter:
- self.optimizer.zero_grad()
- loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
- loss_var.backward()
- self.optimizer.step()
- del loss_var
- # Testing loop, very similar to above, but simplified
- with torch.no_grad():
- self.model.eval()
- testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
- batch_iter = enumerateWithEstimate(
- test_dl,
- "E{} Testing ".format(epoch_ndx),
- start_ndx=test_dl.num_workers,
- )
- for batch_ndx, batch_tup in batch_iter:
- self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
- self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
- def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
- input_tensor, label_tensor, _series_list, _center_list = batch_tup
- input_devtensor = input_tensor.to(self.device)
- label_devtensor = label_tensor.to(self.device)
- prediction_devtensor = self.model(input_devtensor)
- loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
- start_ndx = batch_ndx * batch_size
- end_ndx = start_ndx + label_tensor.size(0)
- metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
- metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
- metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
- # TODO: replace with torch.autograd.detect_anomaly
- # assert np.isfinite(metrics_tensor).all()
- return loss_devtensor.mean()
- def logMetrics(self,
- epoch_ndx,
- trainingMetrics_tensor,
- testingMetrics_tensor,
- classificationThreshold_float=0.5,
- ):
- log.info("E{} {}".format(
- epoch_ndx,
- type(self).__name__,
- ))
- for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
- metrics_ary = metrics_tensor.detach().numpy()[:,:,0]
- assert np.isfinite(metrics_ary).all()
- benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
- benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
- malLabel_mask = ~benLabel_mask
- malPred_mask = ~benPred_mask
- benLabel_count = benLabel_mask.sum()
- malLabel_count = malLabel_mask.sum()
- benCorrect_count = (benLabel_mask & benPred_mask).sum()
- malCorrect_count = (malLabel_mask & malPred_mask).sum()
- metrics_dict = {}
- metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
- metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
- metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
- metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
- metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
- metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
- log.info(("E{} {:8} "
- + "{loss/all:.4f} loss, "
- + "{correct/all:-5.1f}% correct"
- ).format(
- epoch_ndx,
- mode_str,
- **metrics_dict,
- ))
- log.info(("E{} {:8} "
- + "{loss/ben:.4f} loss, "
- + "{correct/ben:-5.1f}% correct").format(
- epoch_ndx,
- mode_str + '_ben',
- **metrics_dict,
- ))
- log.info(("E{} {:8} "
- + "{loss/mal:.4f} loss, "
- + "{correct/mal:-5.1f}% correct").format(
- epoch_ndx,
- mode_str + '_mal',
- **metrics_dict,
- ))
- if __name__ == '__main__':
- sys.exit(LunaTrainingApp().main() or 0)
|