model.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import math
  2. import random
  3. from collections import namedtuple
  4. import torch
  5. from torch import nn as nn
  6. import torch.nn.functional as F
  7. from util.logconf import logging
  8. from util.unet import UNet
  9. log = logging.getLogger(__name__)
  10. # log.setLevel(logging.WARN)
  11. # log.setLevel(logging.INFO)
  12. log.setLevel(logging.DEBUG)
  13. class UNetWrapper(nn.Module):
  14. def __init__(self, **kwargs):
  15. super().__init__()
  16. self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
  17. self.unet = UNet(**kwargs)
  18. self.final = nn.Sigmoid()
  19. self._init_weights()
  20. def _init_weights(self):
  21. init_set = {
  22. nn.Conv2d,
  23. nn.Conv3d,
  24. nn.ConvTranspose2d,
  25. nn.ConvTranspose3d,
  26. nn.Linear,
  27. }
  28. for m in self.modules():
  29. if type(m) in init_set:
  30. nn.init.kaiming_normal_(
  31. m.weight.data, mode='fan_out', nonlinearity='relu', a=0
  32. )
  33. if m.bias is not None:
  34. fan_in, fan_out = \
  35. nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  36. bound = 1 / math.sqrt(fan_out)
  37. nn.init.normal_(m.bias, -bound, bound)
  38. # nn.init.constant_(self.unet.last.bias, -4)
  39. # nn.init.constant_(self.unet.last.bias, 4)
  40. def forward(self, input_batch):
  41. bn_output = self.input_batchnorm(input_batch)
  42. un_output = self.unet(bn_output)
  43. fn_output = self.final(un_output)
  44. return fn_output
  45. class SegmentationAugmentation(nn.Module):
  46. def __init__(
  47. self, flip=None, offset=None, scale=None, rotate=None, noise=None
  48. ):
  49. super().__init__()
  50. self.flip = flip
  51. self.offset = offset
  52. self.scale = scale
  53. self.rotate = rotate
  54. self.noise = noise
  55. def forward(self, input_g, label_g):
  56. transform_t = self._build2dTransformMatrix()
  57. transform_t = transform_t.expand(input_g.shape[0], -1, -1)
  58. transform_t = transform_t.to(input_g.device, torch.float32)
  59. affine_t = F.affine_grid(transform_t[:,:2],
  60. input_g.size(), align_corners=False)
  61. augmented_input_g = F.grid_sample(input_g,
  62. affine_t, padding_mode='border',
  63. align_corners=False)
  64. augmented_label_g = F.grid_sample(label_g.to(torch.float32),
  65. affine_t, padding_mode='border',
  66. align_corners=False)
  67. if self.noise:
  68. noise_t = torch.randn_like(augmented_input_g)
  69. noise_t *= self.noise
  70. augmented_input_g += noise_t
  71. return augmented_input_g, augmented_label_g > 0.5
  72. def _build2dTransformMatrix(self):
  73. transform_t = torch.eye(3)
  74. for i in range(2):
  75. if self.flip:
  76. if random.random() > 0.5:
  77. transform_t[i,i] *= -1
  78. if self.offset:
  79. offset_float = self.offset
  80. random_float = (random.random() * 2 - 1)
  81. transform_t[2,i] = offset_float * random_float
  82. if self.scale:
  83. scale_float = self.scale
  84. random_float = (random.random() * 2 - 1)
  85. transform_t[i,i] *= 1.0 + scale_float * random_float
  86. if self.rotate:
  87. angle_rad = random.random() * math.pi * 2
  88. s = math.sin(angle_rad)
  89. c = math.cos(angle_rad)
  90. rotation_t = torch.tensor([
  91. [c, -s, 0],
  92. [s, c, 0],
  93. [0, 0, 1]])
  94. transform_t @= rotation_t
  95. return transform_t
  96. # MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
  97. #
  98. # class SegmentationMask(nn.Module):
  99. # def __init__(self):
  100. # super().__init__()
  101. #
  102. # self.conv_list = nn.ModuleList([
  103. # self._make_circle_conv(radius) for radius in range(1, 8)
  104. # ])
  105. #
  106. # def _make_circle_conv(self, radius):
  107. # diameter = 1 + radius * 2
  108. #
  109. # a = torch.linspace(-1, 1, steps=diameter)**2
  110. # b = (a[None] + a[:, None])**0.5
  111. #
  112. # circle_weights = (b <= 1.0).to(torch.float32)
  113. #
  114. # conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False)
  115. # conv.weight.data.fill_(1)
  116. # conv.weight.data *= circle_weights / circle_weights.sum()
  117. #
  118. # return conv
  119. #
  120. #
  121. # def erode(self, input_mask, radius, threshold=1):
  122. # conv = self.conv_list[radius - 1]
  123. # input_float = input_mask.to(torch.float32)
  124. # result = conv(input_float)
  125. #
  126. # # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
  127. # # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
  128. #
  129. # return result >= threshold
  130. #
  131. # def deposit(self, input_mask, radius, threshold=0):
  132. # conv = self.conv_list[radius - 1]
  133. # input_float = input_mask.to(torch.float32)
  134. # result = conv(input_float)
  135. #
  136. # # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
  137. # # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
  138. #
  139. # return result > threshold
  140. #
  141. # def fill_cavity(self, input_mask):
  142. # cumsum = input_mask.cumsum(-1)
  143. # filled_mask = (cumsum > 0)
  144. # filled_mask &= (cumsum < cumsum[..., -1:])
  145. # cumsum = input_mask.cumsum(-2)
  146. # filled_mask &= (cumsum > 0)
  147. # filled_mask &= (cumsum < cumsum[..., -1:, :])
  148. #
  149. # return filled_mask
  150. #
  151. #
  152. # def forward(self, input_g, raw_pos_g):
  153. # gcc_g = input_g + 1
  154. #
  155. # with torch.no_grad():
  156. # # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()])
  157. #
  158. # raw_dense_mask = gcc_g > 0.7
  159. # dense_mask = self.deposit(raw_dense_mask, 2)
  160. # dense_mask = self.erode(dense_mask, 6)
  161. # dense_mask = self.deposit(dense_mask, 4)
  162. #
  163. # body_mask = self.fill_cavity(dense_mask)
  164. # air_mask = self.deposit(body_mask & ~dense_mask, 5)
  165. # air_mask = self.erode(air_mask, 6)
  166. #
  167. # lung_mask = self.deposit(air_mask, 5)
  168. #
  169. # raw_candidate_mask = gcc_g > 0.4
  170. # raw_candidate_mask &= air_mask
  171. # candidate_mask = self.erode(raw_candidate_mask, 1)
  172. # candidate_mask = self.deposit(candidate_mask, 1)
  173. #
  174. # pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2)
  175. #
  176. # neg_mask = self.deposit(candidate_mask, 1)
  177. # neg_mask &= ~pos_mask
  178. # neg_mask &= lung_mask
  179. #
  180. # # label_g = (neg_mask | pos_mask).to(torch.float32)
  181. # label_g = (pos_mask).to(torch.float32)
  182. # neg_g = neg_mask.to(torch.float32)
  183. # pos_g = pos_mask.to(torch.float32)
  184. #
  185. # mask_dict = {
  186. # 'raw_dense_mask': raw_dense_mask,
  187. # 'dense_mask': dense_mask,
  188. # 'body_mask': body_mask,
  189. # 'air_mask': air_mask,
  190. # 'raw_candidate_mask': raw_candidate_mask,
  191. # 'candidate_mask': candidate_mask,
  192. # 'lung_mask': lung_mask,
  193. # 'neg_mask': neg_mask,
  194. # 'pos_mask': pos_mask,
  195. # }
  196. #
  197. # return label_g, neg_g, pos_g, lung_mask, mask_dict