# weight_init.py — common weight-initialization helpers.
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
  3. import math
  4. import torch.nn as nn
  5. def constant_init(module, val, bias=0):
  6. nn.init.constant_(module.weight, val)
  7. if hasattr(module, 'bias') and module.bias is not None:
  8. nn.init.constant_(module.bias, bias)
  9. def xavier_init(module, gain=1, bias=0, distribution='normal'):
  10. assert distribution in ['uniform', 'normal']
  11. if distribution == 'uniform':
  12. nn.init.xavier_uniform_(module.weight, gain=gain)
  13. else:
  14. nn.init.xavier_normal_(module.weight, gain=gain)
  15. if hasattr(module, 'bias') and module.bias is not None:
  16. nn.init.constant_(module.bias, bias)
  17. def normal_init(module, mean=0, std=1, bias=0):
  18. nn.init.normal_(module.weight, mean, std)
  19. if hasattr(module, 'bias') and module.bias is not None:
  20. nn.init.constant_(module.bias, bias)
  21. def uniform_init(module, a=0, b=1, bias=0):
  22. nn.init.uniform_(module.weight, a, b)
  23. if hasattr(module, 'bias') and module.bias is not None:
  24. nn.init.constant_(module.bias, bias)
  25. def kaiming_init(module,
  26. a=0,
  27. mode='fan_out',
  28. nonlinearity='relu',
  29. bias=0,
  30. distribution='normal'):
  31. assert distribution in ['uniform', 'normal']
  32. if distribution == 'uniform':
  33. nn.init.kaiming_uniform_(module.weight,
  34. a=a,
  35. mode=mode,
  36. nonlinearity=nonlinearity)
  37. else:
  38. nn.init.kaiming_normal_(module.weight,
  39. a=a,
  40. mode=mode,
  41. nonlinearity=nonlinearity)
  42. if hasattr(module, 'bias') and module.bias is not None:
  43. nn.init.constant_(module.bias, bias)
  44. def caffe2_xavier_init(module, bias=0):
  45. # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
  46. # Acknowledgment to FAIR's internal code
  47. kaiming_init(module,
  48. a=1,
  49. mode='fan_in',
  50. nonlinearity='leaky_relu',
  51. bias=bias,
  52. distribution='uniform')
  53. def c2_xavier_fill(module: nn.Module):
  54. """
  55. Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
  56. Also initializes `module.bias` to 0.
  57. Args:
  58. module (torch.nn.Module): module to initialize.
  59. """
  60. # Caffe2 implementation of XavierFill in fact
  61. # corresponds to kaiming_uniform_ in PyTorch
  62. nn.init.kaiming_uniform_(module.weight, a=1)
  63. if module.bias is not None:
  64. nn.init.constant_(module.bias, 0)
  65. def c2_msra_fill(module: nn.Module):
  66. """
  67. Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
  68. Also initializes `module.bias` to 0.
  69. Args:
  70. module (torch.nn.Module): module to initialize.
  71. """
  72. nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  73. if module.bias is not None:
  74. nn.init.constant_(module.bias, 0)
  75. def init_weights(m: nn.Module, zero_init_final_gamma=False):
  76. """Performs ResNet-style weight initialization."""
  77. if isinstance(m, nn.Conv2d):
  78. # Note that there is no bias due to BN
  79. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  80. m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
  81. elif isinstance(m, nn.BatchNorm2d):
  82. zero_init_gamma = (
  83. hasattr(m, "final_bn") and m.final_bn and zero_init_final_gamma
  84. )
  85. m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
  86. m.bias.data.zero_()
  87. elif isinstance(m, nn.Linear):
  88. m.weight.data.normal_(mean=0.0, std=0.01)
  89. m.bias.data.zero_()