model.py

import math
import random

import numpy as np
import torch
from torch import nn as nn

from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

def augment3d(inp):
    # Build a random affine transform as a 4x4 matrix in homogeneous
    # coordinates, then apply it to the whole batch at once.
    transform_t = torch.eye(4, dtype=torch.float32)
    for i in range(3):
        if True: #'flip' in augmentation_dict:
            if random.random() > 0.5:
                transform_t[i,i] *= -1
        if True: #'offset' in augmentation_dict:
            offset_float = 0.1
            random_float = (random.random() * 2 - 1)
            # The translation lives in the last *column* of the 3x4 affine
            # matrix; writing it into row 3 (as `transform_t[3,i]` did) would
            # be sliced away by `transform_t[:3]` below and have no effect.
            transform_t[i,3] = offset_float * random_float
    if True: #'rotate' in augmentation_dict:
        angle_rad = random.random() * np.pi * 2
        s = np.sin(angle_rad)
        c = np.cos(angle_rad)

        # Rotate in the x-y plane only, leaving the z (index) axis alone.
        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ], dtype=torch.float32)

        transform_t @= rotation_t

    affine_t = torch.nn.functional.affine_grid(
        transform_t[:3].unsqueeze(0).expand(inp.size(0), -1, -1).cuda(),
        inp.shape,
        align_corners=False,
    )
    augmented_chunk = torch.nn.functional.grid_sample(
        inp,
        affine_t,
        padding_mode='border',
        align_corners=False,
    )
    if False: #'noise' in augmentation_dict:
        # Disabled: `augmentation_dict` is not defined in this module, so
        # this branch needs a noise level passed in before it can be enabled.
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augmentation_dict['noise']
        augmented_chunk += noise_t

    return augmented_chunk
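

# Hedged usage sketch (not from the original file): augment3d expects a 5D
# batch (N, C, D, H, W) that is already on the GPU, since the affine matrix
# is moved there with .cuda(). The shape below is illustrative only.
def _demo_augment3d():
    chunk = torch.randn(2, 1, 32, 48, 48, device='cuda')
    augmented = augment3d(chunk)
    assert augmented.shape == chunk.shape  # grid_sample preserves the shape
    return augmented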
class LunaModel(nn.Module):
    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        self.tail_batchnorm = nn.BatchNorm3d(1)

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        self.head_linear = nn.Linear(1152, 2)
        self.head_activation = nn.Softmax(dim=1)

        self._init_weights()
    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    # _calculate_fan_in_and_fan_out is a private PyTorch
                    # helper; see the issue linked above.
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # Note: normal_ takes (mean, std), so this draws from a
                    # normal with mean -bound and std bound, not a uniform
                    # over [-bound, bound] as PyTorch's default init does.
                    nn.init.normal_(m.bias, -bound, bound)
    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        # Flatten each sample's feature volume for the fully connected head.
        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        # Return raw logits and softmax probabilities.
        return linear_output, self.head_activation(linear_output)
class LunaBlock(nn.Module):
    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        block_out = self.conv1(input_batch)
        block_out = self.relu1(block_out)
        block_out = self.conv2(block_out)
        block_out = self.relu2(block_out)

        return self.maxpool(block_out)
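

# Hedged smoke test (not from the original file), assuming 32x48x48 single-
# channel input chunks: four stride-2 max pools reduce that to 2x3x3 with 64
# channels, i.e. 64 * 2 * 3 * 3 = 1152 features, matching head_linear.
def _demo_luna_model():
    model = LunaModel()
    chunk = torch.randn(4, 1, 32, 48, 48)
    logits, probs = model(chunk)
    assert logits.shape == (4, 2)
    assert torch.allclose(probs.sum(dim=1), torch.ones(4))
    return logits, probs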