- import math
- import torch.nn as nn
- from util.logconf import logging
# Module-level logger; verbosity is toggled by commenting the setLevel lines.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)  # full debug output; raise the level for production runs
class LunaModel(nn.Module):
    """3D CNN classifier: a BatchNorm3d "tail", four LunaBlock stages, and a
    linear "head" producing two-class logits plus softmax probabilities.

    Each block halves the spatial dimensions via max-pooling, and the head
    expects 1152 flattened features (8 * conv_channels channels at the default
    conv_channels=8, times the remaining spatial volume).
    # assumes input crops are (N, in_channels, 32, 48, 48), since
    # 1152 = 64 * 2 * 3 * 3 — TODO confirm against the dataset caller
    """

    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        # FIX: was hard-coded nn.BatchNorm3d(1); use in_channels so non-default
        # channel counts work.  Identical behavior at the default in_channels=1.
        self.tail_batchnorm = nn.BatchNorm3d(in_channels)

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        self.head_linear = nn.Linear(1152, 2)
        self.head_softmax = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        """Kaiming-normal init for conv/linear weights; uniform init for biases.

        BUG FIX: the original called nn.init.normal_(m.bias, -bound, bound),
        i.e. a *normal* distribution with mean=-bound and std=bound.  The
        intended initialization (mirroring PyTorch's own default for
        Linear/Conv layers) is a uniform draw from [-bound, bound].
        """
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    # fan_out matches the mode used for the weights above.
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, input_batch):
        """Return (raw logits, softmax probabilities) for a batch.

        Logits feed nn.CrossEntropyLoss during training; the probabilities
        are for reporting/thresholding.
        """
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        # Flatten everything but the batch dimension for the linear head.
        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        return linear_output, self.head_softmax(linear_output)
class LunaBlock(nn.Module):
    """Basic building block: two 3x3x3 convolutions (each followed by an
    in-place ReLU) and a 2x2x2 max-pool.

    The convolutions preserve spatial size (padding=1); the pooling step
    halves every spatial dimension.  Channel count goes from in_channels
    to conv_channels at the first convolution and stays there.
    """

    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        """Apply conv-relu, conv-relu, then max-pool to the batch."""
        out = self.relu1(self.conv1(input_batch))
        out = self.relu2(self.conv2(out))
        return self.maxpool(out)
class AlternateLunaModel(LunaModel):
    """LunaModel variant with a wide-to-narrow channel progression.

    Starts wide (conv_channels=64 by default) and halves the channel count
    at each block, ending at conv_channels // 8 channels, so the linear head
    takes 144 flattened features instead of the parent's 1152.
    """

    def __init__(self, in_channels=1, conv_channels=64):
        # Parent builds (and initializes) its own blocks/head first; we then
        # replace them with the narrowing configuration below.
        super().__init__()

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels // 2)
        self.block3 = LunaBlock(conv_channels // 2, conv_channels // 4)
        self.block4 = LunaBlock(conv_channels // 4, conv_channels // 8)
        self.head_linear = nn.Linear(144, 2)

        # BUG FIX: super().__init__() ran _init_weights() *before* the blocks
        # and head were replaced above, so the replacement modules were left
        # with PyTorch's default initialization.  Re-run the custom init so
        # they get the same Kaiming/uniform scheme as the parent's modules.
        self._init_weights()