model.py

import math

import torch
from torch import nn as nn

from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class LunaModel(nn.Module):
    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
        super().__init__()

        # The input is a single-channel 3D CT crop, so normalize over 3D
        # feature maps; nn.BatchNorm2d would reject the 5D input tensor.
        self.input_batchnorm = nn.BatchNorm3d(1)

        # Stack layer_count blocks of two 3x3x3 convolutions followed by a
        # 2x2x2 max pool; the channel count doubles from block to block.
        layer_list = []
        for layer_ndx in range(layer_count):
            layer_list += [
                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
                nn.ReLU(inplace=True),
                nn.MaxPool3d(2, 2),
            ]
            in_channels = conv_channels
            conv_channels *= 2
        self.convAndPool_seq = nn.Sequential(*layer_list)

        self.fullyConnected_layer = nn.Linear(576, 2)
        self.final = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # nn.init.normal_ takes (mean, std), so this draws biases
                    # from a normal distribution with mean -bound and std bound.
                    nn.init.normal_(m.bias, -bound, bound)

    def forward(self, input_batch):
        bn_output = self.input_batchnorm(input_batch)
        conv_output = self.convAndPool_seq(bn_output)
        # Flatten everything but the batch dimension for the linear head.
        conv_flat = conv_output.view(conv_output.size(0), -1)
        classifier_output = self.fullyConnected_layer(conv_flat)
        # Return raw logits (suitable for nn.CrossEntropyLoss) alongside
        # softmax probabilities.
        return classifier_output, self.final(classifier_output)
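As a quick sanity check (not part of the original listing), the sketch below instantiates the model and pushes a dummy batch through it. The 16x48x48 crop size is an assumption inferred from the nn.Linear(576, 2) head: four 2x2x2 max pools reduce such a crop to 64 channels of 1x3x3, which is exactly 576 flattened features.

if __name__ == '__main__':
    model = LunaModel()
    # Assumed input shape: batch of 4 single-channel 16x48x48 voxel crops.
    input_batch = torch.randn(4, 1, 16, 48, 48)
    logits, probabilities = model(input_batch)
    print(logits.shape, probabilities.shape)  # torch.Size([4, 2]) torch.Size([4, 2])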