@@ -243,51 +243,78 @@ class PointwiseConv(nn.Module):


# ----------------- CNN Modules -----------------
-class Bottleneck(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio = 0.5,
-                 kernel_sizes = [3, 3],
-                 shortcut = True,
-                 act_type = 'silu',
-                 norm_type = 'BN',
-                 depthwise = False,):
-        super(Bottleneck, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)
-        if depthwise:
-            self.cv1 = nn.Sequential(
-                DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
-                PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
-            )
-            self.cv2 = nn.Sequential(
-                DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
-                PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
-            )
-        else:
-            self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
-            self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
-        self.shortcut = shortcut and in_dim == out_dim
+class RepVggBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type='relu', norm_type='BN'):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
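+        # Training-time branches: a 3x3 conv and a 1x1 conv, each followed by BN.
+        # convert_to_deploy() fuses them into a single 3x3 conv for inference.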
+        self.conv1 = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=norm_type)
+        self.conv2 = BasicConv(in_dim, out_dim, kernel_size=1, padding=0, act_type=None, norm_type=norm_type)
+        self.act = get_activation(act_type)

    def forward(self, x):
-        h = self.cv2(self.cv1(x))
+        if hasattr(self, 'conv'):
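+            # Deploy mode: branches already folded into a single conv.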
+            y = self.conv(x)
+        else:
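+            # Training mode: sum of the parallel 3x3 and 1x1 branches.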
+            y = self.conv1(x) + self.conv2(x)

-        return x + h if self.shortcut else h
+        return self.act(y)

-class RTCBlock(nn.Module):
+    def convert_to_deploy(self):
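+        # Fuse the two conv+BN branches into one plain 3x3 conv.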
+        if not hasattr(self, 'conv'):
+            self.conv = nn.Conv2d(self.in_dim, self.out_dim, kernel_size=3, stride=1, padding=1)
+
+        kernel, bias = self.get_equivalent_kernel_bias()
+        self.conv.weight.data = kernel
+        self.conv.bias.data = bias
+
+    def get_equivalent_kernel_bias(self):
+        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
+        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
+
+        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
+
+    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
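+        # Zero-pad the 1x1 kernel to 3x3 so the two kernels can be summed.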
+        if kernel1x1 is None:
+            return 0
+        else:
+            return F.pad(kernel1x1, [1, 1, 1, 1])
+
+    def _fuse_bn_tensor(self, branch: BasicConv):
+        if branch is None:
+            return 0, 0
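+        # Fold BN into the conv: w' = w * (gamma / std), b' = beta - mean * gamma / std.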
+        kernel = branch.conv.weight
+        running_mean = branch.norm.running_mean
+        running_var = branch.norm.running_var
+        gamma = branch.norm.weight
+        beta = branch.norm.bias
+        eps = branch.norm.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+
+        return kernel * t, beta - running_mean * gamma / std
+
+class RepRTCBlock(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
-                 num_blocks = 1,
-                 shortcut = False,
+                 num_blocks = 3,
+                 expansion = 1.0,
                 act_type = 'silu',
                 norm_type = 'BN',
-                 depthwise = False,):
-        super(RTCBlock, self).__init__()
-        self.inter_dim = out_dim // 2
-        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+                 ) -> None:
+        super(RepRTCBlock, self).__init__()
+        self.inter_dim = round(out_dim * expansion)
+        self.input_proj = BasicConv(in_dim, self.inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.m = nn.Sequential(*(
-            RepVggBlock... 
+            RepVggBlock(self.inter_dim, self.inter_dim, act_type, norm_type)
            for _ in range(num_blocks)))
        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

@@ -303,3 +330,4 @@ class RTCBlock(nn.Module):
        out = self.output_proj(torch.cat(out, dim=1))

        return out
+
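
A minimal sanity check for the fusion above — a sketch, not part of the patch, assuming RepVggBlock and its BasicConv/get_activation dependencies are importable from this module; the block must be in eval() mode, since the fusion folds the BatchNorm running statistics:

import torch

block = RepVggBlock(in_dim=64, out_dim=64).eval()  # eval(): fusion uses BN running stats
x = torch.randn(1, 64, 32, 32)

with torch.no_grad():
    y_two_branch = block(x)         # 3x3 + 1x1 branches, each with BN
    block.convert_to_deploy()       # fold both branches into block.conv
    y_fused = block(x)              # single plain 3x3 conv

# The fused conv should reproduce the two-branch output up to float error.
print(torch.allclose(y_two_branch, y_fused, atol=1e-5))  # expected: True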