model.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from torch import nn as nn
  2. from util.logconf import logging
  3. log = logging.getLogger(__name__)
  4. # log.setLevel(logging.WARN)
  5. # log.setLevel(logging.INFO)
  6. log.setLevel(logging.DEBUG)
  7. class LunaModel(nn.Module):
  8. def __init__(self, layer_count, in_channels, conv_channels):
  9. super().__init__()
  10. layer_list = []
  11. for layer_ndx in range(layer_count):
  12. layer_list += [
  13. nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
  14. nn.ReLU(inplace=True),
  15. nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
  16. nn.ReLU(inplace=True),
  17. nn.MaxPool3d(2, 2),
  18. ]
  19. in_channels = conv_channels
  20. conv_channels *= 2
  21. self.convAndPool_seq = nn.Sequential(*layer_list)
  22. self.fullyConnected_layer = nn.Linear(256, 1)
  23. def forward(self, x):
  24. conv_out = self.convAndPool_seq(x)
  25. flattened_out = conv_out.view(conv_out.size(0), -1)
  26. try:
  27. classification_out = self.fullyConnected_layer(flattened_out)
  28. except:
  29. log.debug(flattened_out.size())
  30. raise
  31. return classification_out