import math
import random

import numpy as np
import torch
from torch import nn

from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

def augment3d(inp):
    # Build a random 4x4 affine transform: a possible flip and a small
    # random offset along each axis, plus a rotation in the row/column
    # plane (in-plane voxel spacing is uniform, so this rotation does not
    # distort the data).
    transform_t = torch.eye(4, dtype=torch.float32)
    for i in range(3):
        if True:  # 'flip' in augmentation_dict
            if random.random() > 0.5:
                transform_t[i, i] *= -1
        if True:  # 'offset' in augmentation_dict
            offset_float = 0.1
            random_float = (random.random() * 2 - 1)
            # The translation belongs in the last column of the 3x4 matrix
            # that affine_grid consumes; writing transform_t[3, i] would be
            # sliced away below and silently do nothing.
            transform_t[i, 3] = offset_float * random_float
    if True:  # 'rotate' in augmentation_dict
        angle_rad = random.random() * np.pi * 2
        s = np.sin(angle_rad)
        c = np.cos(angle_rad)

        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ], dtype=torch.float32)

        transform_t @= rotation_t

    affine_t = torch.nn.functional.affine_grid(
        # Match the input's device instead of assuming CUDA.
        transform_t[:3].unsqueeze(0).expand(inp.size(0), -1, -1).to(inp.device),
        inp.shape,
        align_corners=False,
    )
    augmented_chunk = torch.nn.functional.grid_sample(
        inp,
        affine_t,
        padding_mode='border',
        align_corners=False,
    )
    if False:  # 'noise' in augmentation_dict; disabled here, and
        # augmentation_dict is not defined in this standalone version
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augmentation_dict['noise']
        augmented_chunk += noise_t
    return augmented_chunk
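
# An illustrative usage sketch, not part of the original module. The
# (32, 48, 48) chunk size and the _demo_augment3d name are assumptions
# added here; augment3d itself only requires a 5D (N, C, D, H, W) batch.
def _demo_augment3d():
    chunk_batch = torch.randn(4, 1, 32, 48, 48)  # hypothetical candidate batch
    augmented = augment3d(chunk_batch)
    assert augmented.shape == chunk_batch.shape  # resampling preserves shape
    return augmented
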
class LunaModel(nn.Module):
    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        # Tail: normalize the raw input intensities.
        self.tail_batchnorm = nn.BatchNorm3d(1)

        # Backbone: four blocks, doubling the channel count each time.
        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        # Head: flatten the final feature map and classify into two classes.
        self.head_linear = nn.Linear(1152, 2)
        self.head_activation = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.normal_(m.bias, -bound, bound)

    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        # Return both raw logits (for nn.CrossEntropyLoss) and softmax
        # probabilities (for thresholding and reporting).
        return linear_output, self.head_activation(linear_output)
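
# Sanity check for the head's 1152 input features: a sketch assuming
# (32, 48, 48)-voxel candidate chunks (an assumption added here, matching
# the demo above). Each LunaBlock ends in a 2x max-pool, so four blocks
# shrink the volume to 2 x 3 x 3 with conv_channels * 8 = 64 channels,
# and 64 * 2 * 3 * 3 = 1152.
def _expected_flat_features(depth=32, height=48, width=48, conv_channels=8):
    for _ in range(4):  # one 2x max-pool per LunaBlock
        depth, height, width = depth // 2, height // 2, width // 2
    return conv_channels * 8 * depth * height * width  # 64 * 2 * 3 * 3 = 1152
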
class LunaBlock(nn.Module):
    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        # Halve every spatial dimension at the end of the block.
        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        block_out = self.conv1(input_batch)
        block_out = self.relu1(block_out)
        block_out = self.conv2(block_out)
        block_out = self.relu2(block_out)
        return self.maxpool(block_out)
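
if __name__ == '__main__':
    # Smoke test, an illustrative addition rather than part of the original
    # module: augment one random batch and push it through the model,
    # checking that every shape lines up with the 1152-feature head.
    input_batch = _demo_augment3d()
    model = LunaModel()
    logits, probabilities = model(input_batch)
    assert _expected_flat_features() == 1152
    assert logits.shape == (4, 2) and probabilities.shape == (4, 2)
    log.debug("smoke test passed, logits shape: %s", tuple(logits.shape))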