| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- import argparse
- import datetime
- import os
- import socket
- import sys
- import numpy as np
- from torch.utils.tensorboard import SummaryWriter
- import torch
- import torch.nn as nn
- import torch.optim
- from torch.optim import SGD, Adam
- from torch.utils.data import DataLoader
- from util.util import enumerateWithEstimate
- from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
- from util.logconf import logging
- from util.util import xyz2irc
- from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation
- from p2ch13.train_seg import LunaTrainingApp
# Verbosity for this module's logger. DEBUG lets every message through; use
# logging.INFO or logging.WARN here for quieter benchmark runs.
_LOG_LEVEL = logging.DEBUG

log = logging.getLogger(__name__)
log.setLevel(_LOG_LEVEL)
class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
    """Training segmentation dataset that reports a fixed length for benchmarking.

    Only ``__len__`` is overridden, so sample construction is inherited
    unchanged from ``TrainingLuna2dSegmentationDataset``. A constant length
    gives every benchmark run the same number of samples, and therefore the
    same number of DataLoader batches.
    """

    # Number of samples reported by __len__. The original kept 500 and 1000 as
    # commented-out / unreachable alternatives. Override this attribute in a
    # subclass (or on an instance) to benchmark a different size.
    # NOTE(review): assumes the parent's __getitem__ accepts indices up to
    # benchmark_len - 1 even when it holds fewer real samples; confirm.
    benchmark_len = 5000

    def __len__(self):
        """Return ``benchmark_len``, independent of the parent's sample lists."""
        return self.benchmark_len
class LunaBenchmarkApp(LunaTrainingApp):
    """Benchmark driver that times the inherited segmentation training step.

    Reuses ``LunaTrainingApp.doTraining`` but feeds it a fixed-length
    ``BenchmarkLuna2dSegmentationDataset`` and builds no validation loader,
    so a run exercises only the training data pipeline and training step.
    """

    # Number of training epochs ``main`` runs. The original hard-coded a single
    # epoch while logging ``cli_args.epochs`` as the total. The count now lives
    # here so the loop and the log message always agree.
    benchmark_epochs = 1

    def _deviceMultiplier(self):
        """Return the number of GPUs a batch is spread over (1 when not using CUDA)."""
        return torch.cuda.device_count() if self.use_cuda else 1

    def initTrainDl(self):
        """Build the training DataLoader over the fixed-length benchmark dataset.

        ``cli_args.batch_size`` is treated as a per-device size and is
        multiplied by the visible GPU count when CUDA is in use.
        Pinned memory is enabled only when CUDA is in use.

        Returns:
            A ``DataLoader`` wrapping ``BenchmarkLuna2dSegmentationDataset``.
        """
        train_ds = BenchmarkLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
            # NOTE(review): the original left augmentation_dict commented out,
            # so the benchmark runs without augmentation; confirm that is intended.
        )

        batch_size = self.cli_args.batch_size * self._deviceMultiplier()

        return DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

    def main(self):
        """Run ``benchmark_epochs`` training epochs and log the batch layout of each.

        No validation pass is performed, and any value returned by
        ``doTraining`` is ignored.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()

        for epoch_ndx in range(1, self.benchmark_epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                # Report the epochs this loop actually runs; cli_args.epochs
                # is not honored by the benchmark.
                self.benchmark_epochs,
                len(train_dl),
                0,  # validation batches: the benchmark builds no validation loader
                self.cli_args.batch_size,
                self._deviceMultiplier(),
            ))

            self.doTraining(epoch_ndx, train_dl)
if __name__ == '__main__':
    # Script entry point: build the benchmark app and run it.
    benchmark_app = LunaBenchmarkApp()
    benchmark_app.main()
|