# model.py (2.6 KB)
  1. import math
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. log = logging.getLogger(__name__)
  5. # log.setLevel(logging.WARN)
  6. # log.setLevel(logging.INFO)
  7. log.setLevel(logging.DEBUG)
  8. class LunaModel(nn.Module):
  9. def __init__(self, in_channels=1, conv_channels=8):
  10. super().__init__()
  11. self.tail_batchnorm = nn.BatchNorm3d(1)
  12. self.block1 = LunaBlock(in_channels, conv_channels)
  13. self.block2 = LunaBlock(conv_channels, conv_channels * 2)
  14. self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
  15. self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
  16. self.head_linear = nn.Linear(1152, 2)
  17. self.head_softmax = nn.Softmax(dim=1)
  18. self._init_weights()
  19. # see also https://github.com/pytorch/pytorch/issues/18182
  20. def _init_weights(self):
  21. for m in self.modules():
  22. if type(m) in {
  23. nn.Linear,
  24. nn.Conv3d,
  25. nn.Conv2d,
  26. nn.ConvTranspose2d,
  27. nn.ConvTranspose3d,
  28. }:
  29. nn.init.kaiming_normal_(
  30. m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
  31. )
  32. if m.bias is not None:
  33. fan_in, fan_out = \
  34. nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  35. bound = 1 / math.sqrt(fan_out)
  36. nn.init.normal_(m.bias, -bound, bound)
  37. def forward(self, input_batch):
  38. bn_output = self.tail_batchnorm(input_batch)
  39. block_out = self.block1(bn_output)
  40. block_out = self.block2(block_out)
  41. block_out = self.block3(block_out)
  42. block_out = self.block4(block_out)
  43. conv_flat = block_out.view(
  44. block_out.size(0),
  45. -1,
  46. )
  47. linear_output = self.head_linear(conv_flat)
  48. return linear_output, self.head_softmax(linear_output)
  49. class LunaBlock(nn.Module):
  50. def __init__(self, in_channels, conv_channels):
  51. super().__init__()
  52. self.conv1 = nn.Conv3d(
  53. in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
  54. )
  55. self.relu1 = nn.ReLU(inplace=True)
  56. self.conv2 = nn.Conv3d(
  57. conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
  58. )
  59. self.relu2 = nn.ReLU(inplace=True)
  60. self.maxpool = nn.MaxPool3d(2, 2)
  61. def forward(self, input_batch):
  62. block_out = self.conv1(input_batch)
  63. block_out = self.relu1(block_out)
  64. block_out = self.conv2(block_out)
  65. block_out = self.relu2(block_out)
  66. return self.maxpool(block_out)