浏览代码

Removes outdated file.

Eli Stevens 6 年之前
父节点
当前提交
0bd75027dc
共有 1 个文件被更改,包括 0 次插入328 次删除
  1. 0 328
      p2ch10/model_segmentation.py

+ 0 - 328
p2ch10/model_segmentation.py

@@ -1,328 +0,0 @@
-import torch
-from torch import nn as nn
-
-from util.logconf import logging
-from util.unet import UNet
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-# log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-# torch.backends.cudnn.enabled = False
-
-class UNetWrapper(nn.Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-
-        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
-        self.unet = UNet(**kwargs)
-        self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
-
-    def forward(self, input):
-        bn_output = self.batchnorm(input)
-        un_output = self.unet(bn_output)
-        ht_output = self.hardtanh(un_output)
-
-        return ht_output
-
-
-
-class Simple2dSegmentationModel(nn.Module):
-    def __init__(self, layers, in_channels, conv_channels, final_channels):
-        super().__init__()
-        self.layers = layers
-
-        self.in_channels = in_channels
-        self.conv_channels = conv_channels
-        self.final_channels = final_channels
-
-        layer_list = [
-            nn.Conv2d(self.in_channels, self.conv_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(self.conv_channels),
-            # nn.GroupNorm(1, self.conv_channels),
-            # nn.ReLU(inplace=True),
-            nn.LeakyReLU(inplace=True),
-        ]
-
-        for i in range(self.layers):
-            layer_list.extend([
-                nn.Conv2d(self.conv_channels, self.conv_channels, kernel_size=3, padding=1),
-                nn.BatchNorm2d(self.conv_channels),
-                # nn.GroupNorm(1, self.conv_channels),
-                # nn.ReLU(inplace=True),
-                nn.LeakyReLU(inplace=True),
-            ])
-
-        layer_list.extend([
-            nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1, bias=True),
-            nn.Hardtanh(min_val=0, max_val=1),
-        ])
-
-        self.layer_seq = nn.Sequential(*layer_list)
-
-
-    def forward(self, in_data):
-        return self.layer_seq(in_data)
-
-
-class Dense2dSegmentationModel(nn.Module):
-    def __init__(self, layers, input_channels, conv_channels, bottleneck_channels, final_channels):
-        super().__init__()
-        self.layers = layers
-
-        self.input_channels = input_channels
-        self.conv_channels = conv_channels
-        self.bottleneck_channels = bottleneck_channels
-        self.final_channels = final_channels
-
-        self.layer_list = nn.ModuleList()
-
-        for i in range(layers):
-            self.layer_list.append(
-                Dense2dSegmentationBlock(
-                    input_channels + bottleneck_channels * i,
-                    conv_channels,
-                    bottleneck_channels,
-                )
-            )
-
-        self.layer_list.append(
-            Dense2dSegmentationBlock(
-                input_channels + bottleneck_channels * layers,
-                conv_channels,
-                bottleneck_channels,
-                final_channels,
-            )
-        )
-
-        self.htanh_layer = nn.Hardtanh(min_val=0, max_val=1)
-
-    def forward(self, input_tensor):
-        concat_list = [input_tensor]
-        for layer_block in self.layer_list:
-            layer_output = layer_block(torch.cat(concat_list, dim=1))
-            concat_list.append(layer_output)
-
-        return self.htanh_layer(concat_list[-1])
-
-
-class Dense2dSegmentationBlock(nn.Module):
-    def __init__(self, input_channels, conv_channels, bottleneck_channels, final_channels=None):
-        super().__init__()
-
-        self.input_channels = input_channels
-        self.conv_channels = conv_channels
-        self.bottleneck_channels = bottleneck_channels
-        self.final_channels = final_channels or bottleneck_channels
-
-        self.conv1_seq = nn.Sequential(
-            nn.Conv2d(self.input_channels, self.bottleneck_channels, kernel_size=1),
-            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
-            nn.Conv2d(self.conv_channels, self.bottleneck_channels, kernel_size=1),
-            # nn.BatchNorm2d(self.conv_channels),
-            nn.GroupNorm(1, self.bottleneck_channels),
-            # nn.ReLU(inplace=True),
-            nn.LeakyReLU(inplace=True),
-        )
-
-        self.conv2_seq = nn.Sequential(
-            nn.Conv2d(self.input_channels + self.bottleneck_channels, self.bottleneck_channels, kernel_size=1),
-            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
-            nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1),
-            # nn.BatchNorm2d(self.conv_channels),
-            nn.GroupNorm(1, self.final_channels),
-            # nn.ReLU(inplace=True),
-            nn.LeakyReLU(inplace=True),
-        )
-
-    def forward(self, input_tensor):
-        conv1_tensor = self.conv1_seq(input_tensor)
-        conv2_tensor = self.conv2_seq(torch.cat([input_tensor, conv1_tensor], dim=1))
-
-        return conv2_tensor
-
-
-class SegmentationModel(nn.Module):
-    def __init__(self, depth, in_channels, tail_channels=None, out_channels=None, final_channels=None):
-        super().__init__()
-        self.depth = depth
-
-        # self.in_size = in_size
-        # self.tailOut_size = in_size #self.in_size - 4
-        # self.headIn_size = in_size #None
-        # self.out_size = in_size #None
-
-        self.in_channels = in_channels
-        self.tailOut_channels = tail_channels or in_channels * 2
-        self.headIn_channels = None
-        self.out_channels = out_channels or self.tailOut_channels
-        self.final_channels = final_channels
-
-        # assert in_size % 2 == 0, repr([in_size, depth])
-
-        self.tail_seq = nn.Sequential(
-            nn.ReplicationPad3d(2),
-            nn.Conv3d(self.in_channels, self.tailOut_channels, 3),
-            nn.GroupNorm(1, self.tailOut_channels),
-            nn.ReLU(inplace=True),
-            nn.Conv3d(self.tailOut_channels, self.tailOut_channels, 3),
-            nn.GroupNorm(1, self.tailOut_channels),
-            nn.ReLU(inplace=True),
-        )
-
-        if depth:
-            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
-            self.child_layer = SegmentationModel(depth - 1, self.tailOut_channels)
-
-            self.headIn_channels = self.in_channels + self.tailOut_channels + self.child_layer.out_channels
-            # self.headIn_size = self.child_layer.out_size * 2
-            # self.out_size = self.headIn_size #- 4
-
-            # self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
-        else:
-            self.downsample_layer = None
-            self.child_layer = None
-            # self.upsample_layer = None
-
-            self.headIn_channels = self.in_channels + self.tailOut_channels
-            # self.headIn_size = self.tailOut_size
-            # self.out_size = self.headIn_size #- 4
-
-        self.head_seq = nn.Sequential(
-            nn.ReplicationPad3d(2),
-            nn.Conv3d(self.headIn_channels, self.out_channels, 3),
-            nn.GroupNorm(1, self.out_channels),
-            nn.ReLU(inplace=True),
-            nn.Conv3d(self.out_channels, self.out_channels, 3),
-            nn.GroupNorm(1, self.out_channels),
-            nn.ReLU(inplace=True),
-        )
-
-        if self.final_channels:
-            self.final_seq = nn.Sequential(
-                nn.ReplicationPad3d(1),
-                nn.Conv3d(self.out_channels, self.final_channels, 1),
-            )
-        else:
-            self.final_seq = None
-
-    def forward(self, in_data):
-
-        assert in_data.is_contiguous()
-
-        try:
-            tail_out = self.tail_seq(in_data)
-        except:
-            log.debug([in_data.size()])
-            raise
-
-        if self.downsample_layer:
-            down_out = self.downsample_layer(tail_out)
-            child_out = self.child_layer(down_out)
-            # up_out = self.upsample_layer(child_out)
-
-            up_out = nn.functional.interpolate(child_out, scale_factor=2, mode='trilinear')
-
-            # crop_int = (tail_out.size(-1) - up_out.size(-1)) // 2
-            # crop_out = tail_out[:, :, crop_int:-crop_int, crop_int:-crop_int, crop_int:-crop_int]
-            # combined_out = torch.cat([crop_out, up_out], 1)
-
-            combined_out = torch.cat([in_data, tail_out, up_out], 1)
-        else:
-            combined_out = torch.cat([in_data, tail_out], 1)
-
-        head_out = self.head_seq(combined_out)
-
-        if self.final_seq:
-            final_out = self.final_seq(head_out)
-            return final_out
-        else:
-            return head_out
-
-
-class DenseSegmentationModel(nn.Module):
-    def __init__(self, depth, in_channels, conv_channels, final_channels=None):
-        super().__init__()
-        self.depth = depth
-
-        self.in_channels = in_channels
-        self.conv_channels = conv_channels
-        self.final_channels = final_channels
-
-        self.convA_seq = nn.Sequential(
-            nn.Conv3d(self.in_channels, self.conv_channels // 4, 1),
-            nn.ReplicationPad3d(1),
-            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
-            nn.BatchNorm3d(self.conv_channels),
-            nn.ReLU(inplace=True),
-        )
-
-        self.convB_seq = nn.Sequential(
-            nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
-            nn.ReplicationPad3d(1),
-            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
-            nn.BatchNorm3d(self.conv_channels),
-            nn.ReLU(inplace=True),
-        )
-
-        if self.depth:
-            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
-            self.child_layer = SegmentationModel(depth - 1, self.conv_channels, self.conv_channels * 2)
-            self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
-
-            self.convC_seq = nn.Sequential(
-                nn.Conv3d(self.in_channels + self.conv_channels * 3, self.conv_channels // 4, 1),
-                nn.ReplicationPad3d(1),
-                nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
-                nn.BatchNorm3d(self.conv_channels),
-                nn.ReLU(inplace=True),
-            )
-        else:
-            self.downsample_layer = None
-            self.child_layer = None
-            self.upsample_layer = None
-
-            self.convC_seq = nn.Sequential(
-                nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
-                nn.ReplicationPad3d(1),
-                nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
-                nn.BatchNorm3d(self.conv_channels),
-                nn.ReLU(inplace=True),
-            )
-
-        self.convD_seq = nn.Sequential(
-            nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
-            nn.ReplicationPad3d(1),
-            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
-            nn.BatchNorm3d(self.conv_channels),
-            nn.ReLU(inplace=True),
-        )
-
-        if self.final_channels:
-            self.final_seq = nn.Sequential(
-                # nn.ReplicationPad3d(1),
-                nn.Conv3d(self.conv_channels, self.final_channels, 1),
-            )
-        else:
-            self.final_seq = None
-
-    def forward(self, data_in):
-        a_out = self.convA_seq(data_in)
-        b_out = self.convB_seq(torch.cat([data_in, a_out], 1))
-
-        if self.downsample_layer:
-            down_out = self.downsample_layer(b_out)
-            child_out = self.child_layer(down_out)
-            up_out = self.upsample_layer(child_out)
-
-            c_out = self.convC_seq(torch.cat([data_in, b_out, up_out], 1))
-        else:
-            c_out = self.convC_seq(torch.cat([data_in, b_out], 1))
-
-        d_out = self.convD_seq(torch.cat([data_in, c_out], 1))
-
-        if self.final_seq:
-            return self.final_seq(d_out)
-        else:
-            return d_out