| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
def _fuse_conv_bn(conv, bn):
    """Fuse conv and bn into one module.

    In eval mode a BN layer applies the per-channel affine map
    ``y = (x - running_mean) / sqrt(running_var + eps) * weight + bias``.
    Because it is affine, it can be folded into the weight and bias of the
    convolution that produces ``x``.

    Args:
        conv (nn.Module): Conv to be fused. Its ``weight`` must have the
            output channels on dim 0 (the layout of ``nn.Conv1d/2d/3d``).
            The module is modified in place.
        bn (nn.Module): BN to be fused. It must carry running statistics.
            ``affine=False`` layers (no weight/bias) are supported.

    Returns:
        nn.Module: The same ``conv`` object, holding the fused weight and bias.
    """
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)
    # BN layers built with affine=False have weight/bias set to None; they
    # behave like scale 1 and shift 0.
    bn_w = bn.weight if bn.weight is not None else torch.ones_like(
        bn.running_var)
    bn_b = bn.bias if bn.bias is not None else torch.zeros_like(
        bn.running_mean)

    factor = bn_w / torch.sqrt(bn.running_var + bn.eps)
    # Broadcast the per-output-channel factor over all remaining weight dims,
    # so the same code handles any conv rank rather than only 4-D weights.
    factor_shape = [conv.out_channels] + [1] * (conv_w.dim() - 1)
    conv.weight = nn.Parameter(conv_w * factor.reshape(factor_shape))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn_b)
    return conv
def fuse_conv_bn(module):
    """Recursively fuse conv and bn in a module.

    During inference, batch norm layers no longer use batch statistics.
    They apply a fixed per-channel transform built from their running mean
    and variance. That transform can be folded into the preceding conv layer
    to save computation and simplify the network structure.

    Pairing follows the order in which children are registered
    (``named_children``), not the actual ``forward`` data flow. A BN is fused
    only when the child registered immediately before it is an
    ``nn.Conv2d``. Any other child in between (activation, dropout,
    container, ...) breaks the pair, because folding a BN across, e.g., a
    ReLU would change the result.

    Fused BN layers are replaced by ``nn.Identity`` rather than deleted, so
    the module's attribute structure stays the same.

    Args:
        module (nn.Module): Module to be fused. It is modified in place.

    Returns:
        nn.Module: Fused module (the same object as ``module``).
    """
    last_conv = None
    last_conv_name = None

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            # Only fuse a BN that directly follows a conv and has running
            # statistics. With track_running_stats=False, running_mean is
            # None: the BN keeps using batch statistics and cannot be folded.
            if last_conv is not None and child.running_mean is not None:
                fused_conv = _fuse_conv_bn(last_conv, child)
                module._modules[last_conv_name] = fused_conv
                # To reduce changes, set BN as Identity instead of deleting it.
                module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            # Any other child separates the last conv from a later BN, so
            # they must not be paired.
            last_conv = None
            fuse_conv_bn(child)
    return module
|