# model.py
import math
import random

import numpy as np
import torch
from torch import nn

from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
def augment3d(inp):
    """Apply a random flip, offset, and rotation to a batch of 3D chunks."""
    transform_t = torch.eye(4, dtype=torch.float32)
    for i in range(3):
        # Random flip along each axis with probability 0.5.
        if True:  # 'flip' in augmentation_dict:
            if random.random() > 0.5:
                transform_t[i, i] *= -1
        # Random offset of up to 10% of the chunk extent. The offset goes in
        # the translation column so it survives the transform_t[:3] slice
        # passed to affine_grid.
        if True:  # 'offset' in augmentation_dict:
            offset_float = 0.1
            random_float = (random.random() * 2 - 1)
            transform_t[i, 3] = offset_float * random_float
    # Random rotation about the depth axis.
    if True:
        angle_rad = random.random() * np.pi * 2
        s = np.sin(angle_rad)
        c = np.cos(angle_rad)
        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ], dtype=torch.float32)
        transform_t @= rotation_t

    # The sampling grid is built on the GPU, so inp must be a CUDA tensor.
    affine_t = torch.nn.functional.affine_grid(
        transform_t[:3].unsqueeze(0).expand(inp.size(0), -1, -1).cuda(),
        inp.shape,
        align_corners=False,
    )
    augmented_chunk = torch.nn.functional.grid_sample(
        inp,
        affine_t,
        padding_mode='border',
        align_corners=False,
    )
    # Disabled: additive noise augmentation. Enabling it would require an
    # augmentation_dict argument carrying the noise scale.
    if False:  # 'noise' in augmentation_dict:
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augmentation_dict['noise']
        augmented_chunk += noise_t
    return augmented_chunk
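
# A minimal usage sketch for augment3d, assuming CUDA is available and that
# batches are shaped (N, 1, 32, 48, 48) like the nodule chunks implied by the
# models below. The sampling grid is created on the GPU, so the batch must
# already be a CUDA tensor:
#
#     batch = torch.randn(4, 1, 32, 48, 48).cuda()
#     augmented = augment3d(batch)
#     assert augmented.shape == batch.shape  # grid_sample preserves shape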
class LunaModel(nn.Module):
    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        self.tail_batchnorm = nn.BatchNorm3d(1)

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        # 1152 = conv_channels * 8 = 64 output channels times the 2 * 3 * 3
        # voxels left after four rounds of 2x pooling on a 32x48x48 chunk.
        self.head_linear = nn.Linear(1152, 2)
        self.head_activation = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.normal_(m.bias, -bound, bound)

    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        return linear_output, self.head_activation(linear_output)
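
# A minimal usage sketch, assuming the 1x32x48x48 input chunks implied by the
# hard-coded 1152 above (shapes here are illustrative, not from the original
# file):
#
#     model = LunaModel()
#     logits, probs = model(torch.randn(2, 1, 32, 48, 48))
#     # logits.shape == probs.shape == (2, 2); each row of probs sums to 1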
class ModifiedLunaModel(nn.Sequential):
    def __init__(self, in_channels=1, conv_channels=32):
        super().__init__(
            nn.BatchNorm3d(1),
            nn.Conv3d(in_channels, conv_channels, (1, 5, 5), padding=(0, 2, 2)),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(conv_channels, 2 * conv_channels, (1, 5, 5), padding=(0, 2, 2)),
            nn.ReLU(),
            nn.BatchNorm3d(2 * conv_channels),
            nn.MaxPool3d(2),
            nn.Conv3d(2 * conv_channels, 4 * conv_channels, (1, 3, 3), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(4 * conv_channels, 8 * conv_channels, (1, 3, 3), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(8 * conv_channels, 16 * conv_channels, (1, 3, 3), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.Flatten(),
            # 18 = 2 * 3 * 3 voxels left after four rounds of 2x pooling
            # on a 32x48x48 input chunk.
            nn.Linear(18 * 16 * conv_channels, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )
        self._init_weights()

    def forward(self, x):
        x = super().forward(x)
        return x, nn.functional.softmax(x, dim=1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
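
# Shape check, assuming the same 1x32x48x48 chunks: four MaxPool3d(2) stages
# reduce 32x48x48 to 2x3x3 = 18 voxels, which is where the
# 18 * 16 * conv_channels input size of the first nn.Linear comes from.
#
#     model = ModifiedLunaModel()
#     logits, probs = model(torch.randn(2, 1, 32, 48, 48))
#     # logits.shape == (2, 2)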
class LunaBlock(nn.Module):
    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        block_out = self.conv1(input_batch)
        block_out = self.relu1(block_out)
        block_out = self.conv2(block_out)
        block_out = self.relu2(block_out)

        return self.maxpool(block_out)
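
# Hypothetical smoke test (not part of the original file): run both models on
# a random CPU batch shaped like the nodule chunks assumed above.
if __name__ == '__main__':
    batch = torch.randn(2, 1, 32, 48, 48)
    for model_cls in (LunaModel, ModifiedLunaModel):
        logits, probs = model_cls()(batch)
        log.debug('%s -> logits %s, probs %s',
                  model_cls.__name__, tuple(logits.shape), tuple(probs.shape))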