|
|
@@ -1,328 +0,0 @@
|
|
|
-import torch
|
|
|
-from torch import nn as nn
|
|
|
-
|
|
|
-from util.logconf import logging
|
|
|
-from util.unet import UNet
|
|
|
-
|
|
|
-log = logging.getLogger(__name__)
|
|
|
-# log.setLevel(logging.WARN)
|
|
|
-# log.setLevel(logging.INFO)
|
|
|
-log.setLevel(logging.DEBUG)
|
|
|
-
|
|
|
-# torch.backends.cudnn.enabled = False
|
|
|
-
|
|
|
-class UNetWrapper(nn.Module):
|
|
|
- def __init__(self, **kwargs):
|
|
|
- super().__init__()
|
|
|
-
|
|
|
- self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
|
|
|
- self.unet = UNet(**kwargs)
|
|
|
- self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
|
|
|
-
|
|
|
- def forward(self, input):
|
|
|
- bn_output = self.batchnorm(input)
|
|
|
- un_output = self.unet(bn_output)
|
|
|
- ht_output = self.hardtanh(un_output)
|
|
|
-
|
|
|
- return ht_output
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-class Simple2dSegmentationModel(nn.Module):
|
|
|
- def __init__(self, layers, in_channels, conv_channels, final_channels):
|
|
|
- super().__init__()
|
|
|
- self.layers = layers
|
|
|
-
|
|
|
- self.in_channels = in_channels
|
|
|
- self.conv_channels = conv_channels
|
|
|
- self.final_channels = final_channels
|
|
|
-
|
|
|
- layer_list = [
|
|
|
- nn.Conv2d(self.in_channels, self.conv_channels, kernel_size=3, padding=1),
|
|
|
- nn.BatchNorm2d(self.conv_channels),
|
|
|
- # nn.GroupNorm(1, self.conv_channels),
|
|
|
- # nn.ReLU(inplace=True),
|
|
|
- nn.LeakyReLU(inplace=True),
|
|
|
- ]
|
|
|
-
|
|
|
- for i in range(self.layers):
|
|
|
- layer_list.extend([
|
|
|
- nn.Conv2d(self.conv_channels, self.conv_channels, kernel_size=3, padding=1),
|
|
|
- nn.BatchNorm2d(self.conv_channels),
|
|
|
- # nn.GroupNorm(1, self.conv_channels),
|
|
|
- # nn.ReLU(inplace=True),
|
|
|
- nn.LeakyReLU(inplace=True),
|
|
|
- ])
|
|
|
-
|
|
|
- layer_list.extend([
|
|
|
- nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1, bias=True),
|
|
|
- nn.Hardtanh(min_val=0, max_val=1),
|
|
|
- ])
|
|
|
-
|
|
|
- self.layer_seq = nn.Sequential(*layer_list)
|
|
|
-
|
|
|
-
|
|
|
- def forward(self, in_data):
|
|
|
- return self.layer_seq(in_data)
|
|
|
-
|
|
|
-
|
|
|
-class Dense2dSegmentationModel(nn.Module):
|
|
|
- def __init__(self, layers, input_channels, conv_channels, bottleneck_channels, final_channels):
|
|
|
- super().__init__()
|
|
|
- self.layers = layers
|
|
|
-
|
|
|
- self.input_channels = input_channels
|
|
|
- self.conv_channels = conv_channels
|
|
|
- self.bottleneck_channels = bottleneck_channels
|
|
|
- self.final_channels = final_channels
|
|
|
-
|
|
|
- self.layer_list = nn.ModuleList()
|
|
|
-
|
|
|
- for i in range(layers):
|
|
|
- self.layer_list.append(
|
|
|
- Dense2dSegmentationBlock(
|
|
|
- input_channels + bottleneck_channels * i,
|
|
|
- conv_channels,
|
|
|
- bottleneck_channels,
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- self.layer_list.append(
|
|
|
- Dense2dSegmentationBlock(
|
|
|
- input_channels + bottleneck_channels * layers,
|
|
|
- conv_channels,
|
|
|
- bottleneck_channels,
|
|
|
- final_channels,
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- self.htanh_layer = nn.Hardtanh(min_val=0, max_val=1)
|
|
|
-
|
|
|
- def forward(self, input_tensor):
|
|
|
- concat_list = [input_tensor]
|
|
|
- for layer_block in self.layer_list:
|
|
|
- layer_output = layer_block(torch.cat(concat_list, dim=1))
|
|
|
- concat_list.append(layer_output)
|
|
|
-
|
|
|
- return self.htanh_layer(concat_list[-1])
|
|
|
-
|
|
|
-
|
|
|
-class Dense2dSegmentationBlock(nn.Module):
|
|
|
- def __init__(self, input_channels, conv_channels, bottleneck_channels, final_channels=None):
|
|
|
- super().__init__()
|
|
|
-
|
|
|
- self.input_channels = input_channels
|
|
|
- self.conv_channels = conv_channels
|
|
|
- self.bottleneck_channels = bottleneck_channels
|
|
|
- self.final_channels = final_channels or bottleneck_channels
|
|
|
-
|
|
|
- self.conv1_seq = nn.Sequential(
|
|
|
- nn.Conv2d(self.input_channels, self.bottleneck_channels, kernel_size=1),
|
|
|
- nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
|
|
|
- nn.Conv2d(self.conv_channels, self.bottleneck_channels, kernel_size=1),
|
|
|
- # nn.BatchNorm2d(self.conv_channels),
|
|
|
- nn.GroupNorm(1, self.bottleneck_channels),
|
|
|
- # nn.ReLU(inplace=True),
|
|
|
- nn.LeakyReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- self.conv2_seq = nn.Sequential(
|
|
|
- nn.Conv2d(self.input_channels + self.bottleneck_channels, self.bottleneck_channels, kernel_size=1),
|
|
|
- nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
|
|
|
- nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1),
|
|
|
- # nn.BatchNorm2d(self.conv_channels),
|
|
|
- nn.GroupNorm(1, self.final_channels),
|
|
|
- # nn.ReLU(inplace=True),
|
|
|
- nn.LeakyReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- def forward(self, input_tensor):
|
|
|
- conv1_tensor = self.conv1_seq(input_tensor)
|
|
|
- conv2_tensor = self.conv2_seq(torch.cat([input_tensor, conv1_tensor], dim=1))
|
|
|
-
|
|
|
- return conv2_tensor
|
|
|
-
|
|
|
-
|
|
|
-class SegmentationModel(nn.Module):
|
|
|
- def __init__(self, depth, in_channels, tail_channels=None, out_channels=None, final_channels=None):
|
|
|
- super().__init__()
|
|
|
- self.depth = depth
|
|
|
-
|
|
|
- # self.in_size = in_size
|
|
|
- # self.tailOut_size = in_size #self.in_size - 4
|
|
|
- # self.headIn_size = in_size #None
|
|
|
- # self.out_size = in_size #None
|
|
|
-
|
|
|
- self.in_channels = in_channels
|
|
|
- self.tailOut_channels = tail_channels or in_channels * 2
|
|
|
- self.headIn_channels = None
|
|
|
- self.out_channels = out_channels or self.tailOut_channels
|
|
|
- self.final_channels = final_channels
|
|
|
-
|
|
|
- # assert in_size % 2 == 0, repr([in_size, depth])
|
|
|
-
|
|
|
- self.tail_seq = nn.Sequential(
|
|
|
- nn.ReplicationPad3d(2),
|
|
|
- nn.Conv3d(self.in_channels, self.tailOut_channels, 3),
|
|
|
- nn.GroupNorm(1, self.tailOut_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- nn.Conv3d(self.tailOut_channels, self.tailOut_channels, 3),
|
|
|
- nn.GroupNorm(1, self.tailOut_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- if depth:
|
|
|
- self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
|
|
|
- self.child_layer = SegmentationModel(depth - 1, self.tailOut_channels)
|
|
|
-
|
|
|
- self.headIn_channels = self.in_channels + self.tailOut_channels + self.child_layer.out_channels
|
|
|
- # self.headIn_size = self.child_layer.out_size * 2
|
|
|
- # self.out_size = self.headIn_size #- 4
|
|
|
-
|
|
|
- # self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
|
|
|
- else:
|
|
|
- self.downsample_layer = None
|
|
|
- self.child_layer = None
|
|
|
- # self.upsample_layer = None
|
|
|
-
|
|
|
- self.headIn_channels = self.in_channels + self.tailOut_channels
|
|
|
- # self.headIn_size = self.tailOut_size
|
|
|
- # self.out_size = self.headIn_size #- 4
|
|
|
-
|
|
|
- self.head_seq = nn.Sequential(
|
|
|
- nn.ReplicationPad3d(2),
|
|
|
- nn.Conv3d(self.headIn_channels, self.out_channels, 3),
|
|
|
- nn.GroupNorm(1, self.out_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- nn.Conv3d(self.out_channels, self.out_channels, 3),
|
|
|
- nn.GroupNorm(1, self.out_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- if self.final_channels:
|
|
|
- self.final_seq = nn.Sequential(
|
|
|
- nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.out_channels, self.final_channels, 1),
|
|
|
- )
|
|
|
- else:
|
|
|
- self.final_seq = None
|
|
|
-
|
|
|
- def forward(self, in_data):
|
|
|
-
|
|
|
- assert in_data.is_contiguous()
|
|
|
-
|
|
|
- try:
|
|
|
- tail_out = self.tail_seq(in_data)
|
|
|
- except:
|
|
|
- log.debug([in_data.size()])
|
|
|
- raise
|
|
|
-
|
|
|
- if self.downsample_layer:
|
|
|
- down_out = self.downsample_layer(tail_out)
|
|
|
- child_out = self.child_layer(down_out)
|
|
|
- # up_out = self.upsample_layer(child_out)
|
|
|
-
|
|
|
- up_out = nn.functional.interpolate(child_out, scale_factor=2, mode='trilinear')
|
|
|
-
|
|
|
- # crop_int = (tail_out.size(-1) - up_out.size(-1)) // 2
|
|
|
- # crop_out = tail_out[:, :, crop_int:-crop_int, crop_int:-crop_int, crop_int:-crop_int]
|
|
|
- # combined_out = torch.cat([crop_out, up_out], 1)
|
|
|
-
|
|
|
- combined_out = torch.cat([in_data, tail_out, up_out], 1)
|
|
|
- else:
|
|
|
- combined_out = torch.cat([in_data, tail_out], 1)
|
|
|
-
|
|
|
- head_out = self.head_seq(combined_out)
|
|
|
-
|
|
|
- if self.final_seq:
|
|
|
- final_out = self.final_seq(head_out)
|
|
|
- return final_out
|
|
|
- else:
|
|
|
- return head_out
|
|
|
-
|
|
|
-
|
|
|
-class DenseSegmentationModel(nn.Module):
|
|
|
- def __init__(self, depth, in_channels, conv_channels, final_channels=None):
|
|
|
- super().__init__()
|
|
|
- self.depth = depth
|
|
|
-
|
|
|
- self.in_channels = in_channels
|
|
|
- self.conv_channels = conv_channels
|
|
|
- self.final_channels = final_channels
|
|
|
-
|
|
|
- self.convA_seq = nn.Sequential(
|
|
|
- nn.Conv3d(self.in_channels, self.conv_channels // 4, 1),
|
|
|
- nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
|
|
|
- nn.BatchNorm3d(self.conv_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- self.convB_seq = nn.Sequential(
|
|
|
- nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
|
|
|
- nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
|
|
|
- nn.BatchNorm3d(self.conv_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- if self.depth:
|
|
|
- self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
|
|
|
- self.child_layer = SegmentationModel(depth - 1, self.conv_channels, self.conv_channels * 2)
|
|
|
- self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
|
|
|
-
|
|
|
- self.convC_seq = nn.Sequential(
|
|
|
- nn.Conv3d(self.in_channels + self.conv_channels * 3, self.conv_channels // 4, 1),
|
|
|
- nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
|
|
|
- nn.BatchNorm3d(self.conv_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
- else:
|
|
|
- self.downsample_layer = None
|
|
|
- self.child_layer = None
|
|
|
- self.upsample_layer = None
|
|
|
-
|
|
|
- self.convC_seq = nn.Sequential(
|
|
|
- nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
|
|
|
- nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
|
|
|
- nn.BatchNorm3d(self.conv_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- self.convD_seq = nn.Sequential(
|
|
|
- nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
|
|
|
- nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
|
|
|
- nn.BatchNorm3d(self.conv_channels),
|
|
|
- nn.ReLU(inplace=True),
|
|
|
- )
|
|
|
-
|
|
|
- if self.final_channels:
|
|
|
- self.final_seq = nn.Sequential(
|
|
|
- # nn.ReplicationPad3d(1),
|
|
|
- nn.Conv3d(self.conv_channels, self.final_channels, 1),
|
|
|
- )
|
|
|
- else:
|
|
|
- self.final_seq = None
|
|
|
-
|
|
|
- def forward(self, data_in):
|
|
|
- a_out = self.convA_seq(data_in)
|
|
|
- b_out = self.convB_seq(torch.cat([data_in, a_out], 1))
|
|
|
-
|
|
|
- if self.downsample_layer:
|
|
|
- down_out = self.downsample_layer(b_out)
|
|
|
- child_out = self.child_layer(down_out)
|
|
|
- up_out = self.upsample_layer(child_out)
|
|
|
-
|
|
|
- c_out = self.convC_seq(torch.cat([data_in, b_out, up_out], 1))
|
|
|
- else:
|
|
|
- c_out = self.convC_seq(torch.cat([data_in, b_out], 1))
|
|
|
-
|
|
|
- d_out = self.convD_seq(torch.cat([data_in, c_out], 1))
|
|
|
-
|
|
|
- if self.final_seq:
|
|
|
- return self.final_seq(d_out)
|
|
|
- else:
|
|
|
- return d_out
|