# model.py
  1. import math
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. log = logging.getLogger(__name__)
  5. # log.setLevel(logging.WARN)
  6. # log.setLevel(logging.INFO)
  7. log.setLevel(logging.DEBUG)
  8. class LunaModel(nn.Module):
  9. def __init__(self, in_channels=1, conv_channels=8):
  10. super().__init__()
  11. self.tail_batchnorm = nn.BatchNorm3d(1)
  12. self.block1 = LunaBlock(in_channels, conv_channels)
  13. self.block2 = LunaBlock(conv_channels, conv_channels * 2)
  14. self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
  15. self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
  16. self.head_linear = nn.Linear(1152, 2)
  17. self.head_softmax = nn.Softmax(dim=1)
  18. self._init_weights()
  19. # see also https://github.com/pytorch/pytorch/issues/18182
  20. def _init_weights(self):
  21. for m in self.modules():
  22. if type(m) in {
  23. nn.Linear,
  24. nn.Conv3d,
  25. nn.Conv2d,
  26. nn.ConvTranspose2d,
  27. nn.ConvTranspose3d,
  28. }:
  29. nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
  30. if m.bias is not None:
  31. fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  32. bound = 1 / math.sqrt(fan_out)
  33. nn.init.normal_(m.bias, -bound, bound)
  34. def forward(self, input_batch):
  35. bn_output = self.tail_batchnorm(input_batch)
  36. block_out = self.block1(bn_output)
  37. block_out = self.block2(block_out)
  38. block_out = self.block3(block_out)
  39. block_out = self.block4(block_out)
  40. conv_flat = block_out.view(
  41. block_out.size(0),
  42. -1,
  43. )
  44. linear_output = self.head_linear(conv_flat)
  45. return linear_output, self.head_softmax(linear_output)
  46. class LunaBlock(nn.Module):
  47. def __init__(self, in_channels, conv_channels):
  48. super().__init__()
  49. self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
  50. self.relu1 = nn.ReLU(inplace=True)
  51. self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
  52. self.relu2 = nn.ReLU(inplace=True)
  53. self.maxpool = nn.MaxPool3d(2, 2)
  54. def forward(self, input_batch):
  55. block_out = self.conv1(input_batch)
  56. block_out = self.relu1(block_out)
  57. block_out = self.conv2(block_out)
  58. block_out = self.relu2(block_out)
  59. return self.maxpool(block_out)