unet.py

# From https://github.com/jvanvugt/pytorch-unet
# https://raw.githubusercontent.com/jvanvugt/pytorch-unet/master/unet.py
# MIT License
#
# Copyright (c) 2018 Joris
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Adapted from https://discuss.pytorch.org/t/unet-implementation/426

import torch
from torch import nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
                 batch_norm=False, up_mode='upconv'):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        Using the default arguments will yield the exact version used in the
        original paper. See the usage sketch at the bottom of this file for
        the shapes these defaults produce.

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): the number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding so that the output shape
                            is the same as the input shape.
                            This may introduce artifacts.
            batch_norm (bool): use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels

        # Encoder: each block doubles the channel count, starting at 2**wf.
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(UNetConvBlock(prev_channels, 2 ** (wf + i),
                                                padding, batch_norm))
            prev_channels = 2 ** (wf + i)

        # Decoder: each block halves the channel count back down.
        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode,
                                            padding, batch_norm))
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)        # save the skip connection
                x = F.avg_pool2d(x, 2)  # halve the spatial resolution
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])   # pair each up block with its skip
        return self.last(x)
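
# For reference (assuming the default wf=6 and depth=5): the encoder blocks
# produce 64, 128, 256, 512, and 1024 channels, and each decoder block halves
# the width back down to 64 before the final 1x1 convolution maps to
# n_classes. These are the channel widths used in the original paper.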

class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        # block.append(nn.LeakyReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        # block.append(nn.LeakyReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out
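
# Shape note (illustrative numbers): with padding=False each 3x3 convolution
# trims one pixel per border, so a block maps an HxW input to (H-4)x(W-4);
# e.g. UNetConvBlock(1, 64, padding=False, batch_norm=False) applied to a
# 1x1x572x572 tensor returns a 1x64x568x568 tensor. With padding=True the
# spatial size is preserved.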

class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2,
                                         stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
                                    nn.Conv2d(in_size, out_size, kernel_size=1))

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        # With padding=False the encoder (bridge) features are larger than the
        # upsampled decoder features, so crop them symmetrically to match.
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]),
                     diff_x:(diff_x + target_size[1])]
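
    # Example (illustrative numbers): if the bridge is 64x64 spatially and the
    # upsampled decoder tensor is 56x56, center_crop trims 4 rows and columns
    # from each border of the bridge before the two are concatenated.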
    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)  # concatenate along the channel axis
        out = self.conv_block(out)
        return out
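

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file; sizes below are illustrative).
# With the paper-exact defaults (padding=False) the valid convolutions shrink
# the output: a 572x572 input yields a 388x388 map, as in Ronneberger et al.
# With padding=True the output resolution matches the input instead.
if __name__ == '__main__':
    model = UNet(in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
                 up_mode='upconv')
    x = torch.randn(1, 1, 572, 572)  # a single 572x572 grayscale image
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 2, 388, 388])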