norm.py

import torch
import torch.nn as nn


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with eps added before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # Register the affine parameters and running statistics as buffers,
        # not nn.Parameters, so the optimizer never updates them.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints saved from nn.BatchNorm2d carry a num_batches_tracked
        # buffer that this module does not have; drop it so strict loading
        # does not fail.
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Move the reshapes to the beginning to make the op fuser-friendly.
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        # Equivalent to (x - rm) / sqrt(rv + eps) * w + b, folded into a
        # single per-element multiply-add.
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
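

# ---------------------------------------------------------------------------
# A minimal usage sketch (not from the original file): swapping every
# nn.BatchNorm2d in a pretrained backbone for FrozenBatchNorm2d while copying
# over the trained affine weights and running statistics. The helper name
# `freeze_batchnorm` and the __main__ demo below are illustrative assumptions,
# not part of the original module. Note that eps here (1e-5) matches the
# nn.BatchNorm2d default, so outputs agree with the original model in eval mode.
# ---------------------------------------------------------------------------

def freeze_batchnorm(module):
    """Recursively replace nn.BatchNorm2d children with FrozenBatchNorm2d."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(child.num_features)
            # Copy the trained parameters and statistics into the buffers.
            frozen.weight.copy_(child.weight.data)
            frozen.bias.copy_(child.bias.data)
            frozen.running_mean.copy_(child.running_mean)
            frozen.running_var.copy_(child.running_var)
            setattr(module, name, frozen)
        else:
            freeze_batchnorm(child)


if __name__ == "__main__":
    import torchvision

    model = torchvision.models.resnet18()
    freeze_batchnorm(model)
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        print(model(x).shape)  # expected: torch.Size([2, 1000])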