fuse_conv_bn.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


def _fuse_conv_bn(conv, bn):
    """Fuse conv and bn into one module.

    Args:
        conv (nn.Module): Conv to be fused.
        bn (nn.Module): BN to be fused.

    Returns:
        nn.Module: Fused module.
    """
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)

    # In eval mode, BN computes y = (x - mean) / sqrt(var + eps) * gamma + beta.
    # Folding it into the conv gives W' = W * factor and
    # b' = (b - mean) * factor + beta, where factor = gamma / sqrt(var + eps).
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv.weight = nn.Parameter(conv_w *
                               factor.reshape([conv.out_channels, 1, 1, 1]))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
    return conv
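

# The check below is illustrative and not part of the original mmcv file: a
# minimal sketch verifying that fusion preserves outputs once the BN layer
# runs in eval mode with frozen statistics. The helper name
# `_demo_check_fusion` is hypothetical.
def _demo_check_fusion():
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 8, 3, padding=1)
    bn = nn.BatchNorm2d(8)
    # Populate non-trivial running statistics before freezing them.
    bn.train()
    with torch.no_grad():
        bn(conv(torch.randn(16, 3, 16, 16)))
    conv.eval()
    bn.eval()

    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        ref = bn(conv(x))  # reference: conv followed by BN
        fused = _fuse_conv_bn(conv, bn)  # note: mutates conv in place
        assert torch.allclose(fused(x), ref, atol=1e-5)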


def fuse_conv_bn(module):
    """Recursively fuse conv and bn in a module.

    During inference, batch norm layers no longer update their statistics;
    only the per-channel running mean and variance are used. This makes it
    possible to fold each BN into the preceding conv layer, saving
    computation and simplifying the network structure.

    Args:
        module (nn.Module): Module to be fused.

    Returns:
        nn.Module: Fused module.
    """
    last_conv = None
    last_conv_name = None

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_conv_bn(child)
    return module
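

# Illustrative usage, not part of the original mmcv file: a minimal sketch
# that fuses a toy stack of Conv2d/BatchNorm2d layers. The model below is an
# assumption made up for this demo.
if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, padding=1),
        nn.BatchNorm2d(32),
    ).eval()  # fuse only after switching to eval mode
    fused = fuse_conv_bn(model)
    # Each BatchNorm2d has been replaced with nn.Identity; the conv weights
    # now absorb the BN affine transform and running statistics.
    print(fused)
    _demo_check_fusion()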