benchmark_seg.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import argparse
  2. import datetime
  3. import os
  4. import socket
  5. import sys
  6. import numpy as np
  7. from torch.utils.tensorboard import SummaryWriter
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim
  11. from torch.optim import SGD, Adam
  12. from torch.utils.data import DataLoader
  13. from util.util import enumerateWithEstimate
  14. from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
  15. from util.logconf import logging
  16. from util.util import xyz2irc
  17. from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation
  18. from p2ch13.train_seg import LunaTrainingApp
  19. log = logging.getLogger(__name__)
  20. # log.setLevel(logging.WARN)
  21. # log.setLevel(logging.INFO)
  22. log.setLevel(logging.DEBUG)
  23. class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
  24. def __len__(self):
  25. # return 500
  26. return 5000
  27. return 1000
  28. class LunaBenchmarkApp(LunaTrainingApp):
  29. def initTrainDl(self):
  30. train_ds = BenchmarkLuna2dSegmentationDataset(
  31. val_stride=10,
  32. isValSet_bool=False,
  33. contextSlices_count=3,
  34. # augmentation_dict=self.augmentation_dict,
  35. )
  36. batch_size = self.cli_args.batch_size
  37. if self.use_cuda:
  38. batch_size *= torch.cuda.device_count()
  39. train_dl = DataLoader(
  40. train_ds,
  41. batch_size=batch_size,
  42. num_workers=self.cli_args.num_workers,
  43. pin_memory=self.use_cuda,
  44. )
  45. return train_dl
  46. def main(self):
  47. log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
  48. train_dl = self.initTrainDl()
  49. for epoch_ndx in range(1, 2):
  50. log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
  51. epoch_ndx,
  52. self.cli_args.epochs,
  53. len(train_dl),
  54. len([]),
  55. self.cli_args.batch_size,
  56. (torch.cuda.device_count() if self.use_cuda else 1),
  57. ))
  58. self.doTraining(epoch_ndx, train_dl)
  59. if __name__ == '__main__':
  60. LunaBenchmarkApp().main()