model_segmentation.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. import torch
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. from util.unet import UNet
  5. log = logging.getLogger(__name__)
  6. # log.setLevel(logging.WARN)
  7. # log.setLevel(logging.INFO)
  8. log.setLevel(logging.DEBUG)
  9. # torch.backends.cudnn.enabled = False
  10. class UNetWrapper(nn.Module):
  11. def __init__(self, **kwargs):
  12. super().__init__()
  13. self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
  14. self.unet = UNet(**kwargs)
  15. self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
  16. def forward(self, input):
  17. bn_output = self.batchnorm(input)
  18. un_output = self.unet(bn_output)
  19. ht_output = self.hardtanh(un_output)
  20. return ht_output
  21. class Simple2dSegmentationModel(nn.Module):
  22. def __init__(self, layers, in_channels, conv_channels, final_channels):
  23. super().__init__()
  24. self.layers = layers
  25. self.in_channels = in_channels
  26. self.conv_channels = conv_channels
  27. self.final_channels = final_channels
  28. layer_list = [
  29. nn.Conv2d(self.in_channels, self.conv_channels, kernel_size=3, padding=1),
  30. nn.BatchNorm2d(self.conv_channels),
  31. # nn.GroupNorm(1, self.conv_channels),
  32. # nn.ReLU(inplace=True),
  33. nn.LeakyReLU(inplace=True),
  34. ]
  35. for i in range(self.layers):
  36. layer_list.extend([
  37. nn.Conv2d(self.conv_channels, self.conv_channels, kernel_size=3, padding=1),
  38. nn.BatchNorm2d(self.conv_channels),
  39. # nn.GroupNorm(1, self.conv_channels),
  40. # nn.ReLU(inplace=True),
  41. nn.LeakyReLU(inplace=True),
  42. ])
  43. layer_list.extend([
  44. nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1, bias=True),
  45. nn.Hardtanh(min_val=0, max_val=1),
  46. ])
  47. self.layer_seq = nn.Sequential(*layer_list)
  48. def forward(self, in_data):
  49. return self.layer_seq(in_data)
  50. class Dense2dSegmentationModel(nn.Module):
  51. def __init__(self, layers, input_channels, conv_channels, bottleneck_channels, final_channels):
  52. super().__init__()
  53. self.layers = layers
  54. self.input_channels = input_channels
  55. self.conv_channels = conv_channels
  56. self.bottleneck_channels = bottleneck_channels
  57. self.final_channels = final_channels
  58. self.layer_list = nn.ModuleList()
  59. for i in range(layers):
  60. self.layer_list.append(
  61. Dense2dSegmentationBlock(
  62. input_channels + bottleneck_channels * i,
  63. conv_channels,
  64. bottleneck_channels,
  65. )
  66. )
  67. self.layer_list.append(
  68. Dense2dSegmentationBlock(
  69. input_channels + bottleneck_channels * layers,
  70. conv_channels,
  71. bottleneck_channels,
  72. final_channels,
  73. )
  74. )
  75. self.htanh_layer = nn.Hardtanh(min_val=0, max_val=1)
  76. def forward(self, input_tensor):
  77. concat_list = [input_tensor]
  78. for layer_block in self.layer_list:
  79. layer_output = layer_block(torch.cat(concat_list, dim=1))
  80. concat_list.append(layer_output)
  81. return self.htanh_layer(concat_list[-1])
  82. class Dense2dSegmentationBlock(nn.Module):
  83. def __init__(self, input_channels, conv_channels, bottleneck_channels, final_channels=None):
  84. super().__init__()
  85. self.input_channels = input_channels
  86. self.conv_channels = conv_channels
  87. self.bottleneck_channels = bottleneck_channels
  88. self.final_channels = final_channels or bottleneck_channels
  89. self.conv1_seq = nn.Sequential(
  90. nn.Conv2d(self.input_channels, self.bottleneck_channels, kernel_size=1),
  91. nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
  92. nn.Conv2d(self.conv_channels, self.bottleneck_channels, kernel_size=1),
  93. # nn.BatchNorm2d(self.conv_channels),
  94. nn.GroupNorm(1, self.bottleneck_channels),
  95. # nn.ReLU(inplace=True),
  96. nn.LeakyReLU(inplace=True),
  97. )
  98. self.conv2_seq = nn.Sequential(
  99. nn.Conv2d(self.input_channels + self.bottleneck_channels, self.bottleneck_channels, kernel_size=1),
  100. nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
  101. nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1),
  102. # nn.BatchNorm2d(self.conv_channels),
  103. nn.GroupNorm(1, self.final_channels),
  104. # nn.ReLU(inplace=True),
  105. nn.LeakyReLU(inplace=True),
  106. )
  107. def forward(self, input_tensor):
  108. conv1_tensor = self.conv1_seq(input_tensor)
  109. conv2_tensor = self.conv2_seq(torch.cat([input_tensor, conv1_tensor], dim=1))
  110. return conv2_tensor
  111. class SegmentationModel(nn.Module):
  112. def __init__(self, depth, in_channels, tail_channels=None, out_channels=None, final_channels=None):
  113. super().__init__()
  114. self.depth = depth
  115. # self.in_size = in_size
  116. # self.tailOut_size = in_size #self.in_size - 4
  117. # self.headIn_size = in_size #None
  118. # self.out_size = in_size #None
  119. self.in_channels = in_channels
  120. self.tailOut_channels = tail_channels or in_channels * 2
  121. self.headIn_channels = None
  122. self.out_channels = out_channels or self.tailOut_channels
  123. self.final_channels = final_channels
  124. # assert in_size % 2 == 0, repr([in_size, depth])
  125. self.tail_seq = nn.Sequential(
  126. nn.ReplicationPad3d(2),
  127. nn.Conv3d(self.in_channels, self.tailOut_channels, 3),
  128. nn.GroupNorm(1, self.tailOut_channels),
  129. nn.ReLU(inplace=True),
  130. nn.Conv3d(self.tailOut_channels, self.tailOut_channels, 3),
  131. nn.GroupNorm(1, self.tailOut_channels),
  132. nn.ReLU(inplace=True),
  133. )
  134. if depth:
  135. self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
  136. self.child_layer = SegmentationModel(depth - 1, self.tailOut_channels)
  137. self.headIn_channels = self.in_channels + self.tailOut_channels + self.child_layer.out_channels
  138. # self.headIn_size = self.child_layer.out_size * 2
  139. # self.out_size = self.headIn_size #- 4
  140. # self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
  141. else:
  142. self.downsample_layer = None
  143. self.child_layer = None
  144. # self.upsample_layer = None
  145. self.headIn_channels = self.in_channels + self.tailOut_channels
  146. # self.headIn_size = self.tailOut_size
  147. # self.out_size = self.headIn_size #- 4
  148. self.head_seq = nn.Sequential(
  149. nn.ReplicationPad3d(2),
  150. nn.Conv3d(self.headIn_channels, self.out_channels, 3),
  151. nn.GroupNorm(1, self.out_channels),
  152. nn.ReLU(inplace=True),
  153. nn.Conv3d(self.out_channels, self.out_channels, 3),
  154. nn.GroupNorm(1, self.out_channels),
  155. nn.ReLU(inplace=True),
  156. )
  157. if self.final_channels:
  158. self.final_seq = nn.Sequential(
  159. nn.ReplicationPad3d(1),
  160. nn.Conv3d(self.out_channels, self.final_channels, 1),
  161. )
  162. else:
  163. self.final_seq = None
  164. def forward(self, in_data):
  165. assert in_data.is_contiguous()
  166. try:
  167. tail_out = self.tail_seq(in_data)
  168. except:
  169. log.debug([in_data.size()])
  170. raise
  171. if self.downsample_layer:
  172. down_out = self.downsample_layer(tail_out)
  173. child_out = self.child_layer(down_out)
  174. # up_out = self.upsample_layer(child_out)
  175. up_out = nn.functional.interpolate(child_out, scale_factor=2, mode='trilinear')
  176. # crop_int = (tail_out.size(-1) - up_out.size(-1)) // 2
  177. # crop_out = tail_out[:, :, crop_int:-crop_int, crop_int:-crop_int, crop_int:-crop_int]
  178. # combined_out = torch.cat([crop_out, up_out], 1)
  179. combined_out = torch.cat([in_data, tail_out, up_out], 1)
  180. else:
  181. combined_out = torch.cat([in_data, tail_out], 1)
  182. head_out = self.head_seq(combined_out)
  183. if self.final_seq:
  184. final_out = self.final_seq(head_out)
  185. return final_out
  186. else:
  187. return head_out
  188. class DenseSegmentationModel(nn.Module):
  189. def __init__(self, depth, in_channels, conv_channels, final_channels=None):
  190. super().__init__()
  191. self.depth = depth
  192. self.in_channels = in_channels
  193. self.conv_channels = conv_channels
  194. self.final_channels = final_channels
  195. self.convA_seq = nn.Sequential(
  196. nn.Conv3d(self.in_channels, self.conv_channels // 4, 1),
  197. nn.ReplicationPad3d(1),
  198. nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
  199. nn.BatchNorm3d(self.conv_channels),
  200. nn.ReLU(inplace=True),
  201. )
  202. self.convB_seq = nn.Sequential(
  203. nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
  204. nn.ReplicationPad3d(1),
  205. nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
  206. nn.BatchNorm3d(self.conv_channels),
  207. nn.ReLU(inplace=True),
  208. )
  209. if self.depth:
  210. self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
  211. self.child_layer = SegmentationModel(depth - 1, self.conv_channels, self.conv_channels * 2)
  212. self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
  213. self.convC_seq = nn.Sequential(
  214. nn.Conv3d(self.in_channels + self.conv_channels * 3, self.conv_channels // 4, 1),
  215. nn.ReplicationPad3d(1),
  216. nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
  217. nn.BatchNorm3d(self.conv_channels),
  218. nn.ReLU(inplace=True),
  219. )
  220. else:
  221. self.downsample_layer = None
  222. self.child_layer = None
  223. self.upsample_layer = None
  224. self.convC_seq = nn.Sequential(
  225. nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
  226. nn.ReplicationPad3d(1),
  227. nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
  228. nn.BatchNorm3d(self.conv_channels),
  229. nn.ReLU(inplace=True),
  230. )
  231. self.convD_seq = nn.Sequential(
  232. nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
  233. nn.ReplicationPad3d(1),
  234. nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
  235. nn.BatchNorm3d(self.conv_channels),
  236. nn.ReLU(inplace=True),
  237. )
  238. if self.final_channels:
  239. self.final_seq = nn.Sequential(
  240. # nn.ReplicationPad3d(1),
  241. nn.Conv3d(self.conv_channels, self.final_channels, 1),
  242. )
  243. else:
  244. self.final_seq = None
  245. def forward(self, data_in):
  246. a_out = self.convA_seq(data_in)
  247. b_out = self.convB_seq(torch.cat([data_in, a_out], 1))
  248. if self.downsample_layer:
  249. down_out = self.downsample_layer(b_out)
  250. child_out = self.child_layer(down_out)
  251. up_out = self.upsample_layer(child_out)
  252. c_out = self.convC_seq(torch.cat([data_in, b_out, up_out], 1))
  253. else:
  254. c_out = self.convC_seq(torch.cat([data_in, b_out], 1))
  255. d_out = self.convD_seq(torch.cat([data_in, c_out], 1))
  256. if self.final_seq:
  257. return self.final_seq(d_out)
  258. else:
  259. return d_out