model.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import torch
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. from util.unet import UNet
  5. log = logging.getLogger(__name__)
  6. # log.setLevel(logging.WARN)
  7. # log.setLevel(logging.INFO)
  8. log.setLevel(logging.DEBUG)
  9. class LunaModel(nn.Module):
  10. def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
  11. super().__init__()
  12. layer_list = []
  13. for layer_ndx in range(layer_count):
  14. layer_list += [
  15. nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False),
  16. nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this
  17. nn.LeakyReLU(inplace=True), # eli: will assume plan ReLU
  18. nn.Dropout3d(p=0.2), # eli: will assume that p1ch6 doesn't use this
  19. nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=False),
  20. nn.BatchNorm3d(conv_channels),
  21. nn.LeakyReLU(inplace=True),
  22. nn.Dropout3d(p=0.2),
  23. nn.MaxPool3d(2, 2),
  24. # tag::model_init[]
  25. ]
  26. in_channels = conv_channels
  27. conv_channels *= 2
  28. self.convAndPool_seq = nn.Sequential(*layer_list)
  29. self.fullyConnected_layer = nn.Linear(512, 1)
  30. self.final = nn.Hardtanh(min_val=0.0, max_val=1.0)
  31. def forward(self, input_batch):
  32. conv_output = self.convAndPool_seq(input_batch)
  33. conv_flat = conv_output.view(conv_output.size(0), -1)
  34. try:
  35. classifier_output = self.fullyConnected_layer(conv_flat)
  36. except:
  37. log.debug(conv_flat.size())
  38. raise
  39. classifier_output = self.final(classifier_output)
  40. return classifier_output
  41. class UNetWrapper(nn.Module):
  42. def __init__(self, **kwargs):
  43. super().__init__()
  44. self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
  45. self.unet = UNet(**kwargs)
  46. self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
  47. def forward(self, input):
  48. bn_output = self.batchnorm(input)
  49. un_output = self.unet(bn_output)
  50. ht_output = self.hardtanh(un_output)
  51. return ht_output