# model_seg.py
  1. import math
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. from util.unet import UNet
# Module-level logger; util.logconf installs the shared handler/formatter.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)  # most verbose level left on (development setting)
  9. class UNetWrapper(nn.Module):
  10. def __init__(self, **kwargs):
  11. super().__init__()
  12. self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
  13. self.unet = UNet(**kwargs)
  14. self.final = nn.Sigmoid()
  15. self._init_weights()
  16. def _init_weights(self):
  17. init_set = {
  18. nn.Conv2d,
  19. nn.Conv3d,
  20. nn.ConvTranspose2d,
  21. nn.ConvTranspose3d,
  22. nn.Linear,
  23. }
  24. for m in self.modules():
  25. if type(m) in init_set:
  26. nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu', a=0)
  27. if m.bias is not None:
  28. fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  29. bound = 1 / math.sqrt(fan_out)
  30. nn.init.normal_(m.bias, -bound, bound)
  31. def forward(self, input_batch):
  32. bn_output = self.input_batchnorm(input_batch)
  33. un_output = self.unet(bn_output)
  34. fn_output = self.final(un_output)
  35. return fn_output