| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- #!/usr/bin/env python3
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
- import math
- import torch.nn as nn
def constant_init(module, val, bias=0):
    """Fill a module's weight with ``val`` and its bias with ``bias``.

    A parameter that is missing or ``None`` (for example a normalization
    layer built with ``affine=False``) is skipped instead of raising.

    Args:
        module (torch.nn.Module): module to initialize in place.
        val (float): constant written to ``module.weight``.
        bias (float): constant written to ``module.bias``. Default: 0.
    """
    # `getattr(..., None)` covers both "attribute absent" and "attribute is None".
    if getattr(module, 'weight', None) is not None:
        nn.init.constant_(module.weight, val)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Apply Xavier (Glorot) initialization to a module's weight.

    The bias, when present, is set to the constant ``bias``. A weight or
    bias that is missing or ``None`` is skipped instead of raising.

    Args:
        module (torch.nn.Module): module to initialize in place.
        gain (float): scaling factor forwarded to the PyTorch initializer.
            Default: 1.
        bias (float): constant written to ``module.bias``. Default: 0.
        distribution (str): ``'uniform'`` selects ``xavier_uniform_``,
            ``'normal'`` selects ``xavier_normal_``. Default: ``'normal'``.

    Raises:
        AssertionError: if ``distribution`` is not one of the two
            supported names.
    """
    assert distribution in ['uniform', 'normal'], (
        "distribution must be 'uniform' or 'normal'")
    if getattr(module, 'weight', None) is not None:
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
    """Draw a module's weight from a normal distribution.

    The bias, when present, is set to the constant ``bias``. A weight or
    bias that is missing or ``None`` is skipped instead of raising.

    Args:
        module (torch.nn.Module): module to initialize in place.
        mean (float): mean of the normal distribution. Default: 0.
        std (float): standard deviation of the normal distribution.
            Default: 1.
        bias (float): constant written to ``module.bias``. Default: 0.
    """
    if getattr(module, 'weight', None) is not None:
        nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0):
    """Draw a module's weight from a uniform distribution over ``[a, b]``.

    The bias, when present, is set to the constant ``bias``. A weight or
    bias that is missing or ``None`` is skipped instead of raising.

    Args:
        module (torch.nn.Module): module to initialize in place.
        a (float): lower bound of the uniform distribution. Default: 0.
        b (float): upper bound of the uniform distribution. Default: 1.
        bias (float): constant written to ``module.bias``. Default: 0.
    """
    if getattr(module, 'weight', None) is not None:
        nn.init.uniform_(module.weight, a, b)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Apply Kaiming (He) initialization to a module's weight.

    The bias, when present, is set to the constant ``bias``. A weight or
    bias that is missing or ``None`` is skipped instead of raising.

    Args:
        module (torch.nn.Module): module to initialize in place.
        a (float): forwarded as ``a`` to the PyTorch Kaiming initializer.
            Default: 0.
        mode (str): forwarded as ``mode`` (``'fan_in'`` or ``'fan_out'``).
            Default: ``'fan_out'``.
        nonlinearity (str): forwarded as ``nonlinearity``.
            Default: ``'relu'``.
        bias (float): constant written to ``module.bias``. Default: 0.
        distribution (str): ``'uniform'`` selects ``kaiming_uniform_``,
            ``'normal'`` selects ``kaiming_normal_``. Default: ``'normal'``.

    Raises:
        AssertionError: if ``distribution`` is not one of the two
            supported names.
    """
    assert distribution in ['uniform', 'normal'], (
        "distribution must be 'uniform' or 'normal'")
    if getattr(module, 'weight', None) is not None:
        # Both initializers take the same keyword arguments, so pick the
        # function once and make a single call.
        if distribution == 'uniform':
            fill_ = nn.init.kaiming_uniform_
        else:
            fill_ = nn.init.kaiming_normal_
        fill_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
    """Initialize ``module`` the way Caffe2's ``XavierFill`` does.

    Delegates to ``kaiming_init`` with a fixed uniform / ``fan_in`` /
    ``leaky_relu`` configuration and ``a=1``.

    Args:
        module (torch.nn.Module): module to initialize in place.
        bias (float): constant passed through as the bias value. Default: 0.
    """
    # Caffe2's `XavierFill` is what PyTorch calls `kaiming_uniform_`.
    # Acknowledgment to FAIR's internal code.
    xavier_fill_settings = dict(
        a=1,
        mode='fan_in',
        nonlinearity='leaky_relu',
        distribution='uniform',
    )
    kaiming_init(module, bias=bias, **xavier_fill_settings)
def c2_xavier_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0 when the module has a bias.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2 implementation of XavierFill in fact
    # corresponds to kaiming_uniform_ in PyTorch
    nn.init.kaiming_uniform_(module.weight, a=1)
    # Tolerate modules that define no `bias` attribute at all, matching the
    # `hasattr` guard used by the other helpers in this file.
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, 0)
def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0 when the module has a bias.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    # Tolerate modules that define no `bias` attribute at all, matching the
    # `hasattr` guard used by the other helpers in this file.
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, 0)
def init_weights(m: nn.Module, zero_init_final_gamma=False):
    """Performs ResNet-style weight initialization.

    Dispatches on the module type:

    * ``nn.Conv2d``: weight drawn from N(0, sqrt(2 / fan_out)) with
      ``fan_out = kernel_h * kernel_w * out_channels``. The bias is left
      untouched.
    * ``nn.BatchNorm2d``: weight set to 1, or to 0 when
      ``zero_init_final_gamma`` is true and the module carries a truthy
      ``final_bn`` attribute; bias set to 0.
    * ``nn.Linear``: weight drawn from N(0, 0.01); bias set to 0.

    Parameters that are ``None`` (``BatchNorm2d(affine=False)``,
    ``Linear(bias=False)``) are skipped. Any other module type is left
    unchanged.

    Args:
        m (torch.nn.Module): module to initialize in place.
        zero_init_final_gamma (bool): zero the weight of batch-norm layers
            flagged with ``final_bn``. Default: False.
    """
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = bool(
            getattr(m, "final_bn", False) and zero_init_final_gamma
        )
        # Both affine parameters are None when the layer has affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight, 0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
        # A Linear layer built with bias=False has `bias is None`.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
|