```python
import math

from torch import nn as nn

from util.logconf import logging
from util.unet import UNet

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class UNetWrapper(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        # Normalize the raw input before it reaches the U-Net proper;
        # kwargs['in_channels'] tells us how many input channels to expect.
        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.unet = UNet(**kwargs)
        # Squash the U-Net output into per-pixel values in [0, 1].
        self.final = nn.Sigmoid()

        # Kaiming (He) initialization for every convolutional and linear
        # submodule; with a=0, 'leaky_relu' is equivalent to plain ReLU.
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data,
                    mode='fan_out', nonlinearity='leaky_relu', a=0,
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # nn.init.normal_ takes (tensor, mean, std), so this draws
                    # biases from a normal distribution N(-bound, bound).
                    nn.init.normal_(m.bias, -bound, bound)

    def forward(self, input):
        bn_output = self.batchnorm(input)
        un_output = self.unet(bn_output)
        fn_output = self.final(un_output)
        return fn_output
```
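For context, a minimal usage sketch follows. It assumes `util.unet` exposes the widely used U-Net constructor signature (`in_channels`, `n_classes`, `depth`, `wf`, `padding`, `batch_norm`, `up_mode`); the wrapper forwards `**kwargs` to it unchanged and only reads `in_channels` itself, for the input batch norm. The hyperparameter values below are illustrative assumptions, not values prescribed by the code above.

```python
import torch

# A minimal sketch, assuming util.unet's UNet accepts these keyword
# arguments; all hyperparameter values are assumptions for illustration.
model = UNetWrapper(
    in_channels=7,    # read by the wrapper for BatchNorm2d, then passed to UNet
    n_classes=1,      # one output channel: a per-pixel probability map
    depth=3,          # number of down/up-sampling levels (assumption)
    wf=4,             # first layer has 2**wf filters (assumption)
    padding=True,     # keep the output the same spatial size as the input
    batch_norm=True,
    up_mode='upconv',
)

batch = torch.rand(2, 7, 512, 512)   # (N, C, H, W) dummy input
probabilities = model(batch)         # sigmoid output, values in [0, 1]
print(probabilities.shape)           # torch.Size([2, 1, 512, 512])
```

Because the final layer is `nn.Sigmoid`, the output can be fed to a loss that expects probabilities (such as `nn.BCELoss` or a Dice loss) rather than one that applies its own sigmoid.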