# model.py
  1. import math
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. from util.unet import UNet
  5. log = logging.getLogger(__name__)
  6. # log.setLevel(logging.WARN)
  7. # log.setLevel(logging.INFO)
  8. log.setLevel(logging.DEBUG)
  9. class UNetWrapper(nn.Module):
  10. def __init__(self, **kwargs):
  11. super().__init__()
  12. self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
  13. self.unet = UNet(**kwargs)
  14. self.final = nn.Sigmoid()
  15. for m in self.modules():
  16. if type(m) in {
  17. nn.Conv2d,
  18. nn.Conv3d,
  19. nn.ConvTranspose2d,
  20. nn.ConvTranspose3d,
  21. nn.Linear,
  22. }:
  23. nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='leaky_relu', a=0)
  24. if m.bias is not None:
  25. fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  26. bound = 1 / math.sqrt(fan_out)
  27. nn.init.normal_(m.bias, -bound, bound)
  28. def forward(self, input):
  29. bn_output = self.batchnorm(input)
  30. un_output = self.unet(bn_output)
  31. fn_output = self.final(un_output)
  32. return fn_output