model.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from torch import nn as nn
  2. from util.logconf import logging
  3. log = logging.getLogger(__name__)
  4. # log.setLevel(logging.WARN)
  5. # log.setLevel(logging.INFO)
  6. log.setLevel(logging.DEBUG)
  7. class LunaModel(nn.Module):
  8. def __init__(self, layer_count, in_channels, conv_channels):
  9. super().__init__()
  10. layer_list = []
  11. for layer_ndx in range(layer_count):
  12. layer_list += [
  13. nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
  14. # nn.BatchNorm3d(conv_channels),
  15. nn.ReLU(inplace=True),
  16. nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
  17. # nn.BatchNorm3d(conv_channels),
  18. nn.ReLU(inplace=True),
  19. nn.MaxPool3d(2, 2),
  20. ]
  21. in_channels = conv_channels
  22. conv_channels *= 2
  23. self.convAndPool_seq = nn.Sequential(*layer_list)
  24. self.fullyConnected_layer = nn.Linear(256, 1)
  25. def forward(self, x):
  26. conv_out = self.convAndPool_seq(x)
  27. flattened_out = conv_out.view(conv_out.size(0), -1)
  28. try:
  29. classification_out = self.fullyConnected_layer(flattened_out)
  30. except:
  31. log.debug(flattened_out.size())
  32. raise
  33. return classification_out