model.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  import math

  import torch
  import torch.nn as nn

  from util.logconf import logging
  # Module-level logger. Verbosity is pinned to DEBUG here; substitute
  # logging.INFO or logging.WARN to make this module quieter.
  log = logging.getLogger(__name__)
  log.setLevel(logging.DEBUG)
  class LunaModel(nn.Module):
      """3D convolutional two-class classifier.

      Layers built in ``__init__``:

      * ``input_batchnorm``: ``nn.BatchNorm3d(in_channels)`` applied to the raw input.
      * ``convAndPool_seq``: ``layer_count`` repetitions of
        Conv3d(k=3, pad=1) -> ReLU -> Conv3d(k=3, pad=1) -> ReLU -> MaxPool3d(2, 2).
        Each repetition doubles the channel count, so the last one emits
        ``conv_channels * 2 ** (layer_count - 1)`` channels, and each max-pool
        halves (rounding down) every spatial dimension.
      * ``fullyConnected_layer``: ``nn.Linear(fc_in_features, 2)`` on the
        flattened conv output.
      * ``final``: softmax over the two logits.

      Args:
          layer_count: Number of conv/conv/pool stages.
          in_channels: Channel count of the input volume.
          conv_channels: Output channels of the first stage.
          fc_in_features: Length of the flattened conv output, i.e.
              ``final_channels * D' * H' * W'``. The default of 576 matches the
              default stages applied to, for example, a 16x48x48 volume
              (64 * 1 * 3 * 3). It must be changed together with the input
              size, ``layer_count`` or ``conv_channels``.
              NOTE(review): the input size 576 was originally chosen for is not
              visible here; confirm against the data pipeline.
      """

      def __init__(
          self,
          layer_count: int = 4,
          in_channels: int = 1,
          conv_channels: int = 8,
          fc_in_features: int = 576,
      ):
          super().__init__()
          # The stages below are Conv3d, so the input is a 5D batch and must be
          # normalized with BatchNorm3d (BatchNorm2d rejects 5D input). Size it
          # from in_channels rather than a hard-coded 1.
          self.input_batchnorm = nn.BatchNorm3d(in_channels)

          layer_list = []
          for _ in range(layer_count):
              layer_list += [
                  nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.MaxPool3d(2, 2),
              ]
              # The next stage consumes this stage's output and doubles the width.
              in_channels = conv_channels
              conv_channels *= 2
          self.convAndPool_seq = nn.Sequential(*layer_list)

          self.fullyConnected_layer = nn.Linear(fc_in_features, 2)
          self.final = nn.Softmax(dim=1)

          self._init_weights()

      def _init_weights(self):
          """Re-initialize every conv, transposed-conv and linear layer.

          Weights use Kaiming-normal init (``mode='fan_out'``, ReLU gain).
          Biases are drawn uniformly from ``[-1/sqrt(fan_out), 1/sqrt(fan_out)]``.
          See also https://github.com/pytorch/pytorch/issues/18182
          """
          init_types = (
              nn.Conv2d,
              nn.Conv3d,
              nn.ConvTranspose2d,
              nn.ConvTranspose3d,
              nn.Linear,
          )
          for m in self.modules():
              if not isinstance(m, init_types):
                  continue
              nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
              if m.bias is not None:
                  _, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
                  bound = 1 / math.sqrt(fan_out)
                  # uniform_(tensor, a, b) samples [a, b]. The previous
                  # normal_(tensor, -bound, bound) treated these as (mean, std)
                  # and drew from N(-bound, bound**2), shifting every bias negative.
                  # NOTE(review): PyTorch's own default uses fan_in for this bound;
                  # fan_out is kept here to match the weight init mode.
                  nn.init.uniform_(m.bias, -bound, bound)

      def forward(self, input_batch: torch.Tensor):
          """Classify a batch of volumes.

          Args:
              input_batch: 5D tensor ``(N, in_channels, D, H, W)``, as required
                  by ``BatchNorm3d`` and the Conv3d stages.

          Returns:
              ``(logits, probabilities)``, both of shape ``(N, 2)``.
              ``probabilities`` is the softmax of ``logits`` over dim 1.
          """
          bn_output = self.input_batchnorm(input_batch)
          conv_output = self.convAndPool_seq(bn_output)
          # Keep the batch dimension; flatten channels and spatial dims for the head.
          conv_flat = conv_output.flatten(start_dim=1)
          classifier_output = self.fullyConnected_layer(conv_flat)
          return classifier_output, self.final(classifier_output)