| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- import torch
- from torch import nn as nn
- from util.logconf import logging
- from util.unet import UNet
# Module-wide logger. DEBUG is the active verbosity; use logging.INFO or
# logging.WARN instead for quieter output.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# Debugging switch, intentionally left disabled:
# torch.backends.cudnn.enabled = False
class UNetWrapper(nn.Module):
    """Wrap ``UNet`` with an input ``BatchNorm2d`` and a [0, 1] output clamp.

    Every keyword argument is forwarded unchanged to ``UNet``. The
    ``in_channels`` entry is also used to size the leading ``BatchNorm2d``,
    so it must be supplied (otherwise ``KeyError`` is raised).
    """

    def __init__(self, **kwargs):
        super().__init__()
        channel_count = kwargs['in_channels']
        self.batchnorm = nn.BatchNorm2d(channel_count)
        self.unet = UNet(**kwargs)
        # Hardtanh(0, 1) clamps the raw network output into [0, 1].
        self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)

    def forward(self, input):
        """Normalise ``input``, pass it through ``self.unet``, clamp to [0, 1]."""
        normalised = self.batchnorm(input)
        return self.hardtanh(self.unet(normalised))
class Simple2dSegmentationModel(nn.Module):
    """Plain stack of 3x3 conv / BatchNorm2d / LeakyReLU stages for 2D segmentation.

    Layout: one input stage (``in_channels -> conv_channels``), then
    ``layers`` further ``conv_channels -> conv_channels`` stages, then a 1x1
    convolution to ``final_channels`` whose output is clamped to [0, 1] by
    ``Hardtanh``. All 3x3 convolutions use ``padding=1`` (stride 1), so the
    spatial size of the input is preserved.
    """

    def __init__(self, layers, in_channels, conv_channels, final_channels):
        super().__init__()
        self.layers = layers
        self.in_channels = in_channels
        self.conv_channels = conv_channels
        self.final_channels = final_channels

        def conv_bn_act(source_channels):
            # One 3x3 conv stage producing `conv_channels` feature maps.
            return [
                nn.Conv2d(source_channels, conv_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(conv_channels),
                nn.LeakyReLU(inplace=True),
            ]

        modules = conv_bn_act(in_channels)
        for _ in range(layers):
            modules += conv_bn_act(conv_channels)
        modules += [
            nn.Conv2d(conv_channels, final_channels, kernel_size=1, bias=True),
            nn.Hardtanh(min_val=0, max_val=1),
        ]
        self.layer_seq = nn.Sequential(*modules)

    def forward(self, in_data):
        """Run the layer stack; every output value lies in [0, 1]."""
        return self.layer_seq(in_data)
class Dense2dSegmentationModel(nn.Module):
    """Densely connected 2D segmentation network of ``Dense2dSegmentationBlock``s.

    Block ``k`` receives the channel-wise concatenation of the original input
    and the outputs of all ``k`` earlier blocks. The first ``layers`` blocks
    each emit ``bottleneck_channels``. One extra final block emits
    ``final_channels``, or ``bottleneck_channels`` if ``final_channels`` is
    falsy. Its output is clamped to [0, 1] by ``Hardtanh``.
    """

    def __init__(self, layers, input_channels, conv_channels, bottleneck_channels, final_channels):
        super().__init__()
        self.layers = layers
        self.input_channels = input_channels
        self.conv_channels = conv_channels
        self.bottleneck_channels = bottleneck_channels
        self.final_channels = final_channels

        # Input width grows by one bottleneck output per preceding block.
        self.layer_list = nn.ModuleList([
            Dense2dSegmentationBlock(
                input_channels + bottleneck_channels * k,
                conv_channels,
                bottleneck_channels,
            )
            for k in range(layers)
        ])
        self.layer_list.append(
            Dense2dSegmentationBlock(
                input_channels + bottleneck_channels * layers,
                conv_channels,
                bottleneck_channels,
                final_channels,
            )
        )
        self.htanh_layer = nn.Hardtanh(min_val=0, max_val=1)

    def forward(self, input_tensor):
        """Feed each block the concatenation of all previous features; clamp the last."""
        features = [input_tensor]
        for block in self.layer_list:
            features.append(block(torch.cat(features, dim=1)))
        return self.htanh_layer(features[-1])
class Dense2dSegmentationBlock(nn.Module):
    """Two-stage bottleneck block whose second stage also sees the block input.

    Stage 1 maps ``input_channels -> bottleneck_channels``. Stage 2 runs on
    ``cat([input, stage1_output])`` and maps to ``final_channels``, falling
    back to ``bottleneck_channels`` when ``final_channels`` is ``None`` or
    otherwise falsy. Each stage is 1x1 conv -> 3x3 conv (``padding=1``) ->
    1x1 conv -> ``GroupNorm(1, C)`` -> ``LeakyReLU``, so spatial size is
    preserved.
    """

    def __init__(self, input_channels, conv_channels, bottleneck_channels, final_channels=None):
        super().__init__()
        self.input_channels = input_channels
        self.conv_channels = conv_channels
        self.bottleneck_channels = bottleneck_channels
        self.final_channels = final_channels or bottleneck_channels

        self.conv1_seq = self._make_stage(input_channels, self.bottleneck_channels)
        self.conv2_seq = self._make_stage(
            input_channels + bottleneck_channels, self.final_channels,
        )

    def _make_stage(self, source_channels, out_channels):
        # 1x1 reduce -> 3x3 conv -> 1x1 project, then normalise and activate.
        # NOTE(review): there is no nonlinearity between the three convolutions,
        # so they compose to a single linear map -- confirm this is intended.
        return nn.Sequential(
            nn.Conv2d(source_channels, self.bottleneck_channels, kernel_size=1),
            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
            nn.Conv2d(self.conv_channels, out_channels, kernel_size=1),
            nn.GroupNorm(1, out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, input_tensor):
        """Apply stage 1, then stage 2 on the input concatenated with stage 1's output."""
        stage1_out = self.conv1_seq(input_tensor)
        return self.conv2_seq(torch.cat([input_tensor, stage1_out], dim=1))
class SegmentationModel(nn.Module):
    """Recursive 3D encoder/decoder segmentation network.

    Each level runs a *tail* (two padded 3x3x3 conv / GroupNorm / ReLU
    stages, spatial size preserved) on its input. When ``depth`` is truthy,
    the tail output is max-pooled by 2 and processed by a child
    ``SegmentationModel`` of depth ``depth - 1``. The child's output is then
    trilinearly resized back to the tail's spatial size. A *head* with the
    same structure as the tail runs on the channel concatenation
    ``[input, tail_output, upsampled_child_output]``; at depth 0 the last
    term is absent.

    Input is expected to be a contiguous 5-D ``(N, C, D, H, W)`` tensor, as
    required by the ``Conv3d``/``MaxPool3d``/trilinear ops used here.

    Args:
        depth: number of recursive downsampling levels below this one.
        in_channels: channel count of the input tensor.
        tail_channels: tail output channels; ``2 * in_channels`` if falsy.
        out_channels: head output channels; the tail channel count if falsy.
        final_channels: if truthy, a final 1x1x1 conv projects the head
            output to this many channels.
    """

    def __init__(self, depth, in_channels, tail_channels=None, out_channels=None, final_channels=None):
        super().__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.tailOut_channels = tail_channels or in_channels * 2
        self.out_channels = out_channels or self.tailOut_channels
        self.final_channels = final_channels

        # ReplicationPad3d(2) followed by two unpadded 3x3x3 convs keeps the
        # spatial size unchanged (+4 - 2 - 2).
        self.tail_seq = nn.Sequential(
            nn.ReplicationPad3d(2),
            nn.Conv3d(self.in_channels, self.tailOut_channels, 3),
            nn.GroupNorm(1, self.tailOut_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(self.tailOut_channels, self.tailOut_channels, 3),
            nn.GroupNorm(1, self.tailOut_channels),
            nn.ReLU(inplace=True),
        )

        if depth:
            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
            self.child_layer = SegmentationModel(depth - 1, self.tailOut_channels)
            self.headIn_channels = (
                self.in_channels + self.tailOut_channels + self.child_layer.out_channels
            )
        else:
            self.downsample_layer = None
            self.child_layer = None
            self.headIn_channels = self.in_channels + self.tailOut_channels

        self.head_seq = nn.Sequential(
            nn.ReplicationPad3d(2),
            nn.Conv3d(self.headIn_channels, self.out_channels, 3),
            nn.GroupNorm(1, self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(self.out_channels, self.out_channels, 3),
            nn.GroupNorm(1, self.out_channels),
            nn.ReLU(inplace=True),
        )

        if self.final_channels:
            # NOTE(review): padding by 1 before a 1x1x1 conv makes the final
            # output 2 voxels larger than the input in every spatial dimension.
            # Kept as-is because callers and checkpoints may depend on this
            # shape and module layout -- confirm whether it is intended.
            self.final_seq = nn.Sequential(
                nn.ReplicationPad3d(1),
                nn.Conv3d(self.out_channels, self.final_channels, 1),
            )
        else:
            self.final_seq = None

    def forward(self, in_data):
        """Run tail -> (optional recursive child) -> head -> (optional final conv)."""
        assert in_data.is_contiguous()
        try:
            tail_out = self.tail_seq(in_data)
        except Exception:
            # Record the offending input size to help diagnose shape errors.
            log.debug([in_data.size()])
            raise

        if self.downsample_layer is not None:
            down_out = self.downsample_layer(tail_out)
            child_out = self.child_layer(down_out)
            # Resize to tail_out's exact spatial size instead of a fixed 2x:
            # identical for even sizes, and keeps the concat valid when
            # max-pooling floored an odd size. align_corners=False is the
            # PyTorch default; passing it explicitly silences the warning.
            up_out = nn.functional.interpolate(
                child_out,
                size=tail_out.shape[2:],
                mode='trilinear',
                align_corners=False,
            )
            combined_out = torch.cat([in_data, tail_out, up_out], 1)
        else:
            combined_out = torch.cat([in_data, tail_out], 1)

        head_out = self.head_seq(combined_out)
        if self.final_seq is not None:
            return self.final_seq(head_out)
        return head_out
class DenseSegmentationModel(nn.Module):
    """3D segmentation network built from densely connected conv stages.

    Every stage is ``1x1x1 conv (-> conv_channels // 4) -> ReplicationPad3d(1)
    -> 3x3x3 conv (-> conv_channels) -> BatchNorm3d -> ReLU`` and preserves
    spatial size. Stage A sees the input. Stages B, C and D each see the
    input concatenated with the previous stage's output. When ``depth`` is
    truthy, stage C additionally sees stage B's output after max-pooling by
    2, recursive processing by a child model, and 2x trilinear upsampling.

    Args:
        depth: number of recursive downsampling levels below this one.
        in_channels: channel count of the input tensor.
        conv_channels: output channels of every stage.
        final_channels: if truthy, a final 1x1x1 conv projects to this many
            channels (spatial size unchanged).
    """

    def __init__(self, depth, in_channels, conv_channels, final_channels=None):
        super().__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.conv_channels = conv_channels
        self.final_channels = final_channels

        self.convA_seq = self._conv_stage(self.in_channels)
        self.convB_seq = self._conv_stage(self.in_channels + self.conv_channels)
        if self.depth:
            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
            # NOTE(review): the child is a plain SegmentationModel rather than a
            # DenseSegmentationModel. With tail_channels=2*conv_channels its
            # output has 2*conv_channels channels, which stage C's input width
            # below assumes. Confirm the non-dense child is intended.
            self.child_layer = SegmentationModel(depth - 1, self.conv_channels, self.conv_channels * 2)
            # align_corners=False is PyTorch's default for trilinear; passing it
            # explicitly keeps results unchanged and silences the warning.
            self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
            c_in_channels = self.in_channels + self.conv_channels * 3
        else:
            self.downsample_layer = None
            self.child_layer = None
            self.upsample_layer = None
            c_in_channels = self.in_channels + self.conv_channels
        self.convC_seq = self._conv_stage(c_in_channels)
        self.convD_seq = self._conv_stage(self.in_channels + self.conv_channels)

        if self.final_channels:
            # No padding here: a 1x1x1 conv already preserves spatial size.
            self.final_seq = nn.Sequential(
                nn.Conv3d(self.conv_channels, self.final_channels, 1),
            )
        else:
            self.final_seq = None

    def _conv_stage(self, source_channels):
        """Build one ``source_channels -> conv_channels`` stage that keeps spatial size."""
        squeeze_channels = self.conv_channels // 4
        return nn.Sequential(
            nn.Conv3d(source_channels, squeeze_channels, 1),
            nn.ReplicationPad3d(1),
            nn.Conv3d(squeeze_channels, self.conv_channels, 3),
            nn.BatchNorm3d(self.conv_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, data_in):
        """Run stages A-D (with the optional recursive branch) and the final conv."""
        a_out = self.convA_seq(data_in)
        b_out = self.convB_seq(torch.cat([data_in, a_out], 1))
        if self.downsample_layer is not None:
            down_out = self.downsample_layer(b_out)
            child_out = self.child_layer(down_out)
            up_out = self.upsample_layer(child_out)
            c_in = torch.cat([data_in, b_out, up_out], 1)
        else:
            c_in = torch.cat([data_in, b_out], 1)
        c_out = self.convC_seq(c_in)
        d_out = self.convD_seq(torch.cat([data_in, c_out], 1))
        if self.final_seq is not None:
            return self.final_seq(d_out)
        return d_out
|