model.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import torch
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. log = logging.getLogger(__name__)
  5. # log.setLevel(logging.WARN)
  6. # log.setLevel(logging.INFO)
  7. log.setLevel(logging.DEBUG)
  8. class LunaModel(nn.Module):
  9. def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
  10. super().__init__()
  11. layer_list = []
  12. for layer_ndx in range(layer_count):
  13. layer_list += [
  14. nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False),
  15. nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this
  16. nn.LeakyReLU(inplace=True), # eli: will assume plan ReLU
  17. nn.Dropout3d(p=0.2), # eli: will assume that p1ch6 doesn't use this
  18. nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=False),
  19. nn.BatchNorm3d(conv_channels),
  20. nn.LeakyReLU(inplace=True),
  21. nn.Dropout3d(p=0.2),
  22. nn.MaxPool3d(2, 2),
  23. # tag::model_init[]
  24. ]
  25. in_channels = conv_channels
  26. conv_channels *= 2
  27. self.convAndPool_seq = nn.Sequential(*layer_list)
  28. self.fullyConnected_layer = nn.Linear(512, 1)
  29. self.final = nn.Hardtanh(min_val=0.0, max_val=1.0)
  30. def forward(self, input_batch):
  31. conv_output = self.convAndPool_seq(input_batch)
  32. conv_flat = conv_output.view(conv_output.size(0), -1)
  33. try:
  34. classifier_output = self.fullyConnected_layer(conv_flat)
  35. except:
  36. log.debug(conv_flat.size())
  37. raise
  38. classifier_output = self.final(classifier_output)
  39. return classifier_output