# norm.py
  1. import torch
  2. class FrozenBatchNorm2d(torch.nn.Module):
  3. def __init__(self, n):
  4. super(FrozenBatchNorm2d, self).__init__()
  5. self.register_buffer("weight", torch.ones(n))
  6. self.register_buffer("bias", torch.zeros(n))
  7. self.register_buffer("running_mean", torch.zeros(n))
  8. self.register_buffer("running_var", torch.ones(n))
  9. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  10. missing_keys, unexpected_keys, error_msgs):
  11. num_batches_tracked_key = prefix + 'num_batches_tracked'
  12. if num_batches_tracked_key in state_dict:
  13. del state_dict[num_batches_tracked_key]
  14. super(FrozenBatchNorm2d, self)._load_from_state_dict(
  15. state_dict, prefix, local_metadata, strict,
  16. missing_keys, unexpected_keys, error_msgs)
  17. def forward(self, x):
  18. # move reshapes to the beginning
  19. # to make it fuser-friendly
  20. w = self.weight.reshape(1, -1, 1, 1)
  21. b = self.bias.reshape(1, -1, 1, 1)
  22. rv = self.running_var.reshape(1, -1, 1, 1)
  23. rm = self.running_mean.reshape(1, -1, 1, 1)
  24. eps = 1e-5
  25. scale = w * (rv + eps).rsqrt()
  26. bias = b - rm * scale
  27. return x * scale + bias