| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- # From https://github.com/jvanvugt/pytorch-unet
- # https://raw.githubusercontent.com/jvanvugt/pytorch-unet/master/unet.py
- # MIT License
- #
- # Copyright (c) 2018 Joris
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
- # Adapted from https://discuss.pytorch.org/t/unet-implementation/426
- import torch
- from torch import nn
- import torch.nn.functional as F
class UNet(nn.Module):
    """U-Net encoder/decoder network for dense (per-pixel) prediction.

    Architecture after "U-Net: Convolutional Networks for Biomedical Image
    Segmentation" (Ronneberger et al., 2015, https://arxiv.org/abs/1505.04597).

    NOTE: this implementation downsamples with 2x2 *average* pooling, whereas
    the paper uses max pooling, so the default arguments follow the paper's
    layout but are not a bit-exact reproduction of it.
    """

    def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
                 batch_norm=False, up_mode='upconv'):
        """Build the contracting path, the expanding path and the 1x1 head.

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): number of resolution levels in the contracting path
            wf (int): the first level has 2**wf filters; each deeper level
                      doubles that count
            padding (bool): if True, pad the 3x3 convolutions by one pixel so
                            they keep the spatial size.
                            This may introduce artifacts
            batch_norm (bool): use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        # Filter count per resolution level, shallowest first.
        widths = [2 ** (wf + level) for level in range(depth)]
        channels = in_channels

        # Contracting path: one double-conv block per level.
        self.down_path = nn.ModuleList()
        for width in widths:
            self.down_path.append(
                UNetConvBlock(channels, width, padding, batch_norm))
            channels = width

        # Expanding path: walks back up through every level but the deepest.
        self.up_path = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.up_path.append(
                UNetUpBlock(channels, width, up_mode, padding, batch_norm))
            channels = width

        # 1x1 convolution mapping the final features to the output channels.
        self.last = nn.Conv2d(channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the ``n_classes``-channel map."""
        skips = []
        deepest = len(self.down_path) - 1
        for level, encoder in enumerate(self.down_path):
            x = encoder(x)
            if level == deepest:
                # The bottom level is neither pooled nor kept as a skip.
                continue
            skips.append(x)
            x = F.avg_pool2d(x, 2)

        # Consume the skip tensors deepest-first, mirroring the encoder.
        for decoder in self.up_path:
            x = decoder(x, skips.pop())

        return self.last(x)
class UNetConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU and optional BatchNorm."""

    def __init__(self, in_size, out_size, padding, batch_norm):
        """Assemble the conv -> ReLU -> (BatchNorm) stanza twice.

        Args:
            in_size (int): channels of the incoming tensor
            out_size (int): channels produced by both convolutions
            padding (bool): converted with ``int()`` and used as the
                            convolution padding (True -> 1, False -> 0)
            batch_norm (bool): append a BatchNorm2d after each ReLU
        """
        super(UNetConvBlock, self).__init__()
        pad = int(padding)

        layers = []
        # The first conv maps in_size -> out_size, the second out_size -> out_size.
        for conv_in in (in_size, out_size):
            layers.append(nn.Conv2d(conv_in, out_size, kernel_size=3,
                                    padding=pad))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.block(x)
class UNetUpBlock(nn.Module):
    """One decoder stage: upsample, concatenate the cropped skip tensor, convolve."""

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        """Create the upsampling layer and the following double-conv block.

        Args:
            in_size (int): channels of the tensor coming from the deeper level;
                           also the channel count the conv block expects after
                           concatenation
            out_size (int): channels produced by the upsampling layer and by
                            the conv block
            up_mode (str): 'upconv' for a stride-2 transposed convolution,
                           'upsample' for bilinear upsampling followed by a
                           1x1 convolution
            padding (bool): forwarded to UNetConvBlock
            batch_norm (bool): forwarded to UNetConvBlock

        Raises:
            ValueError: if ``up_mode`` is not 'upconv' or 'upsample'.
        """
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2,
                                         stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        else:
            # Fail at construction time; otherwise self.up would be missing and
            # the mistake would only surface as an AttributeError in forward().
            raise ValueError(
                "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode))

        # The conv block consumes the concatenation of the upsampled tensor
        # (out_size channels) and the skip tensor, in_size channels in total.
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the central ``target_size`` (height, width) window of ``layer``.

        ``layer`` is indexed as a 4-D tensor; only its last two dimensions are
        cropped. When the size difference is odd, the extra row/column is
        dropped from the bottom/right.
        """
        _, _, layer_height, layer_width = layer.size()
        target_height, target_width = target_size[0], target_size[1]
        top = (layer_height - target_height) // 2
        left = (layer_width - target_width) // 2
        return layer[:, :, top:top + target_height, left:left + target_width]

    def forward(self, x, bridge):
        """Upsample ``x``, join it with the cropped ``bridge`` and convolve.

        Args:
            x: tensor from the deeper level
            bridge: skip tensor from the matching encoder level; it is
                    center-cropped to the spatial size of the upsampled ``x``
        """
        up = self.up(x)
        cropped = self.center_crop(bridge, up.shape[2:])
        # Concatenate along the channel dimension (dim 1).
        merged = torch.cat([up, cropped], 1)
        return self.conv_block(merged)
|