#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Weight-initialization helpers for ``torch.nn`` modules.

Every helper mutates the given module's parameters in place and returns
``None``.
"""

import math

import torch.nn as nn


def _fill_bias(module, value):
    """Set ``module.bias`` to the constant ``value`` if the module has one.

    Modules without a ``bias`` attribute, or whose ``bias`` is ``None``
    (e.g. layers built with ``bias=False``), are left untouched.
    """
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, value)


def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and the bias (if any) with ``bias``.

    Args:
        module (torch.nn.Module): module to initialize.
        val: constant written to every element of ``module.weight``.
        bias: constant written to ``module.bias`` when it exists.
    """
    nn.init.constant_(module.weight, val)
    _fill_bias(module, bias)


def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Apply Xavier initialization to ``module.weight``.

    Args:
        module (torch.nn.Module): module to initialize.
        gain: scaling factor forwarded to the ``nn.init.xavier_*`` call.
        bias: constant written to ``module.bias`` when it exists.
        distribution (str): ``'uniform'`` or ``'normal'``.

    Raises:
        AssertionError: if ``distribution`` is not one of the two options.
    """
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    _fill_bias(module, bias)


def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from a normal distribution.

    Args:
        module (torch.nn.Module): module to initialize.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
        bias: constant written to ``module.bias`` when it exists.
    """
    nn.init.normal_(module.weight, mean, std)
    _fill_bias(module, bias)


def uniform_init(module, a=0, b=1, bias=0):
    """Draw ``module.weight`` from a uniform distribution between ``a`` and ``b``.

    Args:
        module (torch.nn.Module): module to initialize.
        a: lower bound of the uniform distribution.
        b: upper bound of the uniform distribution.
        bias: constant written to ``module.bias`` when it exists.
    """
    nn.init.uniform_(module.weight, a, b)
    _fill_bias(module, bias)


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Apply Kaiming (He) initialization to ``module.weight``.

    Args:
        module (torch.nn.Module): module to initialize.
        a: forwarded as ``a`` to the ``nn.init.kaiming_*`` call.
        mode (str): forwarded as ``mode`` (e.g. ``'fan_in'``, ``'fan_out'``).
        nonlinearity (str): forwarded as ``nonlinearity``.
        bias: constant written to ``module.bias`` when it exists.
        distribution (str): ``'uniform'`` or ``'normal'``.

    Raises:
        AssertionError: if ``distribution`` is not one of the two options.
    """
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    _fill_bias(module, bias)


def caffe2_xavier_init(module, bias=0):
    """Initialize ``module`` the way Caffe2's ``XavierFill`` does.

    Args:
        module (torch.nn.Module): module to initialize.
        bias: constant written to ``module.bias`` when it exists.
    """
    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
    # Acknowledgment to FAIR's internal code
    kaiming_init(
        module,
        a=1,
        mode='fan_in',
        nonlinearity='leaky_relu',
        bias=bias,
        distribution='uniform')


def c2_xavier_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2 implementation of XavierFill in fact
    # corresponds to kaiming_uniform_ in PyTorch
    nn.init.kaiming_uniform_(module.weight, a=1)
    # Tolerates modules that have no `bias` attribute at all, matching the
    # other helpers in this file.
    _fill_bias(module, 0)


def c2_msra_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    # Tolerates modules that have no `bias` attribute at all, matching the
    # other helpers in this file.
    _fill_bias(module, 0)


def init_weights(m: nn.Module, zero_init_final_gamma=False):
    """Performs ResNet-style weight initialization.

    Only ``nn.Conv2d``, ``nn.BatchNorm2d`` and ``nn.Linear`` modules are
    touched; any other module type is left unchanged.

    Args:
        m (torch.nn.Module): module to initialize.
        zero_init_final_gamma (bool): when true, a ``BatchNorm2d`` module that
            carries a truthy ``final_bn`` attribute gets its weight (gamma)
            set to 0 instead of 1.
    """
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN, so only the weight is set.
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = (
            getattr(m, "final_bn", False) and zero_init_final_gamma
        )
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight, 0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
        # bias is None when the layer was built with bias=False.
        if m.bias is not None:
            nn.init.zeros_(m.bias)