@@ -34,6 +34,7 @@ def get_norm(norm_type, dim):
     elif norm_type == 'GN':
         return nn.GroupNorm(num_groups=32, num_channels=dim)

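Note: get_norm hard-codes num_groups=32, so the GN branch only works when dim is divisible by 32. A quick sanity check (sketch):

import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=32, num_channels=64)  # ok: 64 % 32 == 0
y = norm(torch.randn(2, 64, 8, 8))                   # shape preserved: (2, 64, 8, 8)
# num_channels=48 would raise ValueError at construction: 48 % 32 != 0
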
+## Basic Conv Module
 class Conv(nn.Module):
     def __init__(self,
                  c1,                   # in channels
@@ -76,78 +77,87 @@ class Conv(nn.Module):
     def forward(self, x):
         return self.convs(x)

+## Partial Conv Module
+class PartialConv(nn.Module):
+    def __init__(self, in_dim, out_dim, split_ratio=0.25, kernel_size=1, stride=1, act_type=None, norm_type=None):
+        super().__init__()
+        # ----------- Basic Parameters -----------
+        assert in_dim == out_dim, 'PartialConv requires in_dim == out_dim'
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.split_ratio = split_ratio
+        self.split_dim = round(in_dim * split_ratio)
+        self.untouched_dim = in_dim - self.split_dim
+        self.kernel_size = kernel_size
+        self.padding = kernel_size // 2
+        self.stride = stride
+        self.act_type = act_type
+        self.norm_type = norm_type
+        # ----------- Network Parameters -----------
+        self.partial_conv = Conv(self.split_dim, self.split_dim, self.kernel_size, self.padding, self.stride, act_type=act_type, norm_type=norm_type)
+
+    def forward(self, x):
+        # convolve only the first split_dim channels; the rest pass through untouched
+        x1, x2 = torch.split(x, [self.split_dim, self.untouched_dim], dim=1)
+        x1 = self.partial_conv(x1)
+        x = torch.cat((x1, x2), 1)
+
+        return x
+
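A quick shape check for PartialConv (sketch; plain nn.Conv2d stands in for the Conv wrapper above):

import torch
import torch.nn as nn

in_dim, split_ratio = 64, 0.25
split_dim = round(in_dim * split_ratio)        # 16 channels get convolved
untouched_dim = in_dim - split_dim             # 48 channels pass through as-is

pconv = nn.Conv2d(split_dim, split_dim, kernel_size=1)
x = torch.randn(2, in_dim, 32, 32)
x1, x2 = torch.split(x, [split_dim, untouched_dim], dim=1)
out = torch.cat((pconv(x1), x2), dim=1)
print(out.shape)  # torch.Size([2, 64, 32, 32]) -- channel count unchanged

Only split_dim of the in_dim channels are convolved, which is where the FLOPs saving comes from.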

 # ---------------------------- Base Modules ----------------------------
-## Multi-head Mixed Conv (MHMC)
-class MultiHeadMixedConv(nn.Module):
-    def __init__(self, in_dim, out_dim, num_heads=4, shortcut=False, act_type='silu', norm_type='BN', depthwise=False):
+## Faster Module
+class FasterModule(nn.Module):
+    def __init__(self, in_dim, out_dim, split_ratio=0.25, kernel_size=3, stride=1, shortcut=True, act_type='silu', norm_type='BN'):
         super().__init__()
-        # -------------- Basic parameters --------------
+        # ----------- Basic Parameters -----------
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.num_heads = num_heads
-        self.shortcut = shortcut
-        self.head_dim = in_dim // num_heads
-        # -------------- Network parameters --------------
-        ## Scale Modulation
-        self.mixed_convs = nn.ModuleList()
-        for i in range(num_heads):
-            self.mixed_convs.append(
-                Conv(self.head_dim, self.head_dim, k=2*i+1, p=i, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-            )
-        ## Out-proj
-        self.out_proj = Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
+        self.split_ratio = split_ratio
+        self.shortcut = shortcut and in_dim == out_dim  # residual only when shapes match
+        self.act_type = act_type
+        self.norm_type = norm_type
+        # ----------- Network Parameters -----------
+        self.partial_conv = PartialConv(in_dim, in_dim, split_ratio, kernel_size, stride, act_type=None, norm_type=None)
+        self.expand_layer = Conv(in_dim, in_dim*2, k=1, act_type=act_type, norm_type=norm_type)
+        self.project_layer = Conv(in_dim*2, out_dim, k=1, act_type=None, norm_type=None)

     def forward(self, x):
-        xs = torch.chunk(x, self.num_heads, dim=1)
-        ys = [mixed_conv(x_h) for x_h, mixed_conv in zip(xs, self.mixed_convs)]
-        out = self.out_proj(torch.cat(ys, dim=1))
+        # partial conv -> 1x1 expand -> 1x1 project back
+        h = self.project_layer(self.expand_layer(self.partial_conv(x)))

-        return out + x if self.shortcut else out
+        return x + h if self.shortcut else h
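FasterModule is a small inverted bottleneck: partial 3x3 conv, 1x1 expansion to twice the width, 1x1 projection back, plus a residual when in_dim == out_dim. A rough parameter comparison against a plain 3x3 conv (sketch with bias-free nn.Conv2d stand-ins):

import torch.nn as nn

dim = 64
split_dim = round(dim * 0.25)  # 16

partial = nn.Conv2d(split_dim, split_dim, 3, padding=1, bias=False)  # 16*16*9 = 2304 params
expand  = nn.Conv2d(dim, dim * 2, 1, bias=False)                     # 64*128  = 8192
project = nn.Conv2d(dim * 2, dim, 1, bias=False)                     # 128*64  = 8192
full    = nn.Conv2d(dim, dim, 3, padding=1, bias=False)              # 64*64*9 = 36864

def count(m):
    return sum(p.numel() for p in m.parameters())

print(count(partial) + count(expand) + count(project), 'vs', count(full))  # 18688 vs 36864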

-## Mixed Convolution Block
-class MCBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, nblocks=1, num_heads=4, shortcut=False, act_type='silu', norm_type='BN', depthwise=False):
+## CSP-style FasterBlock
+class FasterBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, split_ratio=0.5, num_blocks=1, shortcut=True, act_type='silu', norm_type='BN'):
         super().__init__()
         # -------------- Basic parameters --------------
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.nblocks = nblocks
-        self.num_heads = num_heads
-        self.shortcut = shortcut
+        self.split_ratio = split_ratio
+        self.num_blocks = num_blocks
         self.inter_dim = in_dim // 2
         # -------------- Network parameters --------------
-        ## branch-1
-        self.cv1 = Conv(self.in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(self.in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        ## branch-2
-        self.smblocks = nn.Sequential(*[
-            MultiHeadMixedConv(self.inter_dim, self.inter_dim, self.num_heads, self.shortcut, act_type, norm_type, depthwise)
-            for _ in range(nblocks)])
-        ## out proj
+        self.cv1 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.blocks = nn.Sequential(*[
+            FasterModule(self.inter_dim, self.inter_dim, split_ratio, 3, 1, shortcut, act_type, norm_type)
+            for _ in range(self.num_blocks)])
         self.out_proj = Conv(self.inter_dim*2, out_dim, k=1, act_type=act_type, norm_type=norm_type)

     def forward(self, x):
-        # branch-1
         x1 = self.cv1(x)
-        # branch-2
-        x2 = self.smblocks(self.cv2(x))
-        # output
-        out = torch.cat([x1, x2], dim=1)
-        out = self.out_proj(out)
+        x2 = self.blocks(self.cv2(x))

-        return out
+        return self.out_proj(torch.cat([x1, x2], dim=1))
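FasterBlock follows the usual CSP layout: one 1x1 branch is kept as-is, the other runs through the FasterModule stack, and the two halves are concatenated and fused. A shape sketch (nn.Conv2d and nn.Identity stand in for Conv and the block stack):

import torch
import torch.nn as nn

in_dim, out_dim = 128, 128
inter_dim = in_dim // 2   # 64 channels per branch

cv1 = nn.Conv2d(in_dim, inter_dim, 1)
cv2 = nn.Conv2d(in_dim, inter_dim, 1)
blocks = nn.Identity()                      # stand-in for the FasterModule stack
out_proj = nn.Conv2d(inter_dim * 2, out_dim, 1)

x = torch.randn(2, in_dim, 20, 20)
y = out_proj(torch.cat([cv1(x), blocks(cv2(x))], dim=1))
print(y.shape)  # torch.Size([2, 128, 20, 20])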

 ## DownSample Block
 class DSBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, num_heads=4, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
         self.inter_dim = out_dim // 2
-        self.num_heads = num_heads
         # branch-1
         self.maxpool = nn.Sequential(
             Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type),
@@ -174,16 +184,15 @@ class DSBlock(nn.Module):
 # ---------------------------- FPN Modules ----------------------------
 ## build fpn's core block
 def build_fpn_block(cfg, in_dim, out_dim):
-    if cfg['fpn_core_block'] == 'mcblock':
-        layer = MCBlock(in_dim=in_dim,
-                        out_dim=out_dim,
-                        nblocks=round(cfg['depth'] * 3),
-                        num_heads=cfg['fpn_num_heads'],
-                        shortcut=False,
-                        act_type=cfg['fpn_act'],
-                        norm_type=cfg['fpn_norm'],
-                        depthwise=cfg['fpn_depthwise']
-                        )
+    if cfg['fpn_core_block'] == 'faster_block':
+        layer = FasterBlock(in_dim=in_dim,
+                            out_dim=out_dim,
+                            split_ratio=cfg['fpn_split_ratio'],
+                            num_blocks=round(3 * cfg['depth']),
+                            shortcut=False,
+                            act_type=cfg['fpn_act'],
+                            norm_type=cfg['fpn_norm'],
+                            )

     return layer
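
For reference, a call could look like the following; the cfg values here are illustrative, not taken from an actual config file:

cfg = {
    'fpn_core_block': 'faster_block',
    'fpn_split_ratio': 0.5,
    'depth': 1.0,              # num_blocks = round(3 * 1.0) = 3
    'fpn_act': 'silu',
    'fpn_norm': 'BN',
}
layer = build_fpn_block(cfg, in_dim=256, out_dim=256)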