@@ -1,7 +1,9 @@
+import numpy as np
import torch
import torch.nn as nn


+# ---------------------------- 2D CNN ----------------------------
class SiLU(nn.Module):
    """export-friendly version of nn.SiLU()"""

@@ -34,7 +36,7 @@ def get_norm(norm_type, dim):
        return nn.GroupNorm(num_groups=32, num_channels=dim)


-# Basic conv layer
+## Basic conv layer
class Conv(nn.Module):
    def __init__(self,
                 c1,  # in channels
@@ -77,7 +79,8 @@ class Conv(nn.Module):
        return self.convs(x)


-# ELAN Block
+# ---------------------------- YOLOv7 Modules ----------------------------
+## ELAN-Block proposed by YOLOv7
class ELANBlock(nn.Module):
    def __init__(self, in_dim, out_dim, expand_ratio=0.5, depth=2.0, act_type='silu', norm_type='BN', depthwise=False):
        super(ELANBlock, self).__init__()
@@ -107,7 +110,7 @@ class ELANBlock(nn.Module):
        return out


-# ELAN Block for PaFPN
+## PaFPN's ELAN-Block proposed by YOLOv7
class ELANBlockFPN(nn.Module):
    def __init__(self, in_dim, out_dim, expand_ratio=0.5, nbranch=4, depth=1, act_type='silu', norm_type='BN', depthwise=False):
        super(ELANBlockFPN, self).__init__()
@@ -147,7 +150,7 @@ class ELANBlockFPN(nn.Module):
        return out


-# DownSample Block
+## DownSample Block proposed by YOLOv7
class DownSample(nn.Module):
    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
        super().__init__()
@@ -165,3 +168,177 @@ class DownSample(nn.Module):
        out = torch.cat([x1, x2], dim=1)

        return out
+
+
+# ---------------------------- RepConv Modules ----------------------------
+class RepConv(nn.Module):
+    """
+    The code is adapted from https://github.com/WongKinYiu/yolov7/models/common.py
+    """
+    # Re-parameterized convolution (RepVGG)
+    # https://arxiv.org/abs/2101.03697
+
+    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, act_type='silu', deploy=False):
+        super(RepConv, self).__init__()
+        # -------------- Basic parameters --------------
+        self.deploy = deploy
+        self.groups = g
+        self.in_channels = c1
+        self.out_channels = c2
+
+        # -------------- Network parameters --------------
+        if deploy:
+            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=True)  # single fused conv used at inference time
+
+        else:
+            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)  # BN-only identity branch, only when shapes are preserved
+
+            self.rbr_dense = nn.Sequential(  # 3x3 conv + BN branch
+                nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False),
+                nn.BatchNorm2d(num_features=c2),
+            )
+
+            self.rbr_1x1 = nn.Sequential(  # 1x1 conv + BN branch
+                nn.Conv2d(c1, c2, kernel_size=1, stride=s, groups=g, bias=False),  # groups kept consistent with the 3x3 branch so the kernels can be fused
+                nn.BatchNorm2d(num_features=c2),
+            )
+        self.act = get_activation(act_type)
+
+
+    def forward(self, inputs):
+        if hasattr(self, "rbr_reparam"):  # already re-parameterized (deploy mode)
+            return self.act(self.rbr_reparam(inputs))
+
+        if self.rbr_identity is None:
+            id_out = 0
+        else:
+            id_out = self.rbr_identity(inputs)
+
+        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
+
+    def get_equivalent_kernel_bias(self):
+        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
+        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
+        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+        return (
+            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,  # sum of the three fused 3x3 kernels
+            bias3x3 + bias1x1 + biasid,
+        )
+
+    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+        if kernel1x1 is None:
+            return 0
+        else:
+            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])  # zero-pad the 1x1 kernel to 3x3
+
+    def _fuse_bn_tensor(self, branch):
+        if branch is None:
+            return 0, 0
+        if isinstance(branch, nn.Sequential):  # conv + BN branch (3x3 or 1x1)
+            kernel = branch[0].weight
+            running_mean = branch[1].running_mean
+            running_var = branch[1].running_var
+            gamma = branch[1].weight
+            beta = branch[1].bias
+            eps = branch[1].eps
+        else:  # BN-only identity branch
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, "id_tensor"):
+                input_dim = self.in_channels // self.groups
+                kernel_value = np.zeros(
+                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
+                )
+                for i in range(self.in_channels):
+                    kernel_value[i, i % input_dim, 1, 1] = 1  # identity kernel: 1 at the center tap
+                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std  # BN folded into an equivalent kernel and bias
+
+    def repvgg_convert(self):
+        kernel, bias = self.get_equivalent_kernel_bias()
+        return (
+            kernel.detach().cpu().numpy(),
+            bias.detach().cpu().numpy(),
+        )
+
+    def fuse_conv_bn(self, conv, bn):
+
+        std = (bn.running_var + bn.eps).sqrt()
+        bias = bn.bias - bn.running_mean * bn.weight / std
+
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        weights = conv.weight * t
+
+        bn = nn.Identity()
+        conv = nn.Conv2d(in_channels=conv.in_channels,
+                         out_channels=conv.out_channels,
+                         kernel_size=conv.kernel_size,
+                         stride=conv.stride,
+                         padding=conv.padding,
+                         dilation=conv.dilation,
+                         groups=conv.groups,
+                         bias=True,
+                         padding_mode=conv.padding_mode)
+
+        conv.weight = torch.nn.Parameter(weights)
+        conv.bias = torch.nn.Parameter(bias)
+        return conv
+
+    def fuse_repvgg_block(self):
+        if self.deploy:
+            return
+
+        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])  # fuse 3x3 conv + BN
+
+        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])  # fuse 1x1 conv + BN
+        rbr_1x1_bias = self.rbr_1x1.bias
+        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])  # pad the fused 1x1 kernel to 3x3
+
+        # Fuse self.rbr_identity
+        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
+            identity_conv_1x1 = nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                groups=self.groups,
+                bias=False)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
+
+            identity_conv_1x1.weight.data.fill_(0.0)
+            identity_conv_1x1.weight.data.fill_diagonal_(1.0)  # build a 1x1 identity kernel
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
+
+            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)  # fuse the identity BN into the 1x1 conv
+            bias_identity_expanded = identity_conv_1x1.bias
+            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
+        else:
+            bias_identity_expanded = torch.nn.Parameter(torch.zeros_like(rbr_1x1_bias))
+            weight_identity_expanded = torch.nn.Parameter(torch.zeros_like(weight_1x1_expanded))
+
+        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)  # sum the three fused kernels
+        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)  # and their biases
+
+        self.rbr_reparam = self.rbr_dense  # rbr_dense now holds the fully fused 3x3 conv
+        self.deploy = True
+
+        if self.rbr_identity is not None:
+            del self.rbr_identity
+            self.rbr_identity = None
+
+        if self.rbr_1x1 is not None:
+            del self.rbr_1x1
+            self.rbr_1x1 = None
+
+        if self.rbr_dense is not None:
+            del self.rbr_dense
+            self.rbr_dense = None
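A quick sanity check one could run against this patch (a minimal sketch, not part of the diff; it assumes RepConv and the get_activation helper it relies on are importable from the patched module): in eval mode, the single conv produced by fuse_repvgg_block() should reproduce the output of the training-time three-branch forward pass.

    import torch
    # RepConv is assumed to be imported from the patched module (path depends on the repo layout).

    m = RepConv(c1=64, c2=64, k=3, s=1, p=1)
    m.eval()  # fusion is exact only when BatchNorm uses its running statistics

    x = torch.randn(1, 64, 32, 32)
    with torch.no_grad():
        y_branches = m(x)       # dense 3x3 + 1x1 + identity branches
        m.fuse_repvgg_block()   # collapse the branches into rbr_reparam
        y_fused = m(x)

    print(torch.allclose(y_branches, y_fused, atol=1e-5))  # expected to print True (outputs match up to numerical tolerance)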