@@ -1,10 +1,43 @@
+import math
import torch
import torch.nn as nn


+# ----------------- MLP modules -----------------
+class MLP(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
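+# A minimal usage sketch (illustrative, not from the original code):
+#   mlp = MLP(in_dim=256, hidden_dim=512, out_dim=4, num_layers=3)
+# builds Linear(256, 512) -> Linear(512, 512) -> Linear(512, 4),
+# applying ReLU after every layer except the last.
+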
+class FFN(nn.Module):
+    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
+        super().__init__()
+        self.ffn_dim = round(d_model * mlp_ratio)
+        self.linear1 = nn.Linear(d_model, self.ffn_dim)
+        self.activation = get_activation(act_type)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(self.ffn_dim, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm(src)
+
+        return src
+
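+# Note: FFN follows the post-norm Transformer convention,
+# LayerNorm(x + Dropout(linear2(Dropout(act(linear1(x)))))),
+# so normalization is applied after the residual addition.
+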

# ----------------- CNN modules -----------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+def get_conv2d(c1, c2, k, p, s, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)

    return conv

@@ -79,46 +112,172 @@ class FrozenBatchNorm2d(torch.nn.Module):
        bias = b - rm * scale
        return x * scale + bias

-class Conv(nn.Module):
+class BasicConv(nn.Module):
    def __init__(self,
-                 c1,                        # in channels
-                 c2,                        # out channels
-                 k=1,                       # kernel size
-                 p=0,                       # padding
-                 s=1,                       # padding
-                 d=1,                       # dilation
-                 act_type  :str = 'lrelu',  # activation
-                 norm_type :str ='BN',      # normalization
-                 depthwise :bool =False):
-        super(Conv, self).__init__()
-        convs = []
+                 in_dim,                    # in channels
+                 out_dim,                   # out channels
+                 kernel_size=1,             # kernel size
+                 padding=0,                 # padding
+                 stride=1,                  # stride
+                 act_type  :str = 'lrelu',  # activation
+                 norm_type :str = 'BN',     # normalization
+                 ):
+        super(BasicConv, self).__init__()
        add_bias = False if norm_type else True
-        if depthwise:
-            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
-            # depthwise conv
-            if norm_type:
-                convs.append(get_norm(norm_type, c1))
-            if act_type:
-                convs.append(get_activation(act_type))
-            # pointwise conv
-            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
+        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class DepthwiseConv(nn.Module):
+    def __init__(self,
+                 in_dim,                    # in channels
+                 out_dim,                   # out channels
+                 kernel_size=1,             # kernel size
+                 padding=0,                 # padding
+                 stride=1,                  # stride
+                 act_type  :str = None,     # activation
+                 norm_type :str = 'BN',     # normalization
+                 ):
+        super(DepthwiseConv, self).__init__()
+        assert in_dim == out_dim
+        add_bias = False if norm_type else True
+        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class PointwiseConv(nn.Module):
+    def __init__(self,
+                 in_dim,                    # in channels
+                 out_dim,                   # out channels
+                 act_type  :str = 'lrelu',  # activation
+                 norm_type :str = 'BN',     # normalization
+                 ):
+        super(PointwiseConv, self).__init__()
+        add_bias = False if norm_type else True
+        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
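+# Depthwise separable convolution, as used by Bottleneck below (rationale
+# note, not original code): a k x k depthwise conv costs in_dim*k*k weights
+# and the 1x1 pointwise conv costs in_dim*out_dim, versus in_dim*out_dim*k*k
+# for a standard k x k conv.
+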
+## YOLOv8's Bottleneck
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio = 0.5,
+                 kernel_sizes = [3, 3],
+                 shortcut     = True,
+                 act_type     = 'silu',
+                 norm_type    = 'BN',
+                 depthwise    = False,):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        if depthwise:
+            self.cv1 = nn.Sequential(
+                DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
+                PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
+            )
+            self.cv2 = nn.Sequential(
+                DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
+                PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
+            )
        else:
-            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-
-        self.convs = nn.Sequential(*convs)
+            self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
+            self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
+        self.shortcut = shortcut and in_dim == out_dim

+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
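+# Note: the residual add runs only when shortcut=True and in_dim == out_dim,
+# so the two tensors always have matching shapes at the addition.
+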
+# YOLOv8's StageBlock
+class RTCBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks = 1,
+                 shortcut   = False,
+                 act_type   = 'silu',
+                 norm_type  = 'BN',
+                 depthwise  = False,):
+        super(RTCBlock, self).__init__()
+        self.inter_dim = out_dim // 2
+        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(*(
+            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
+            for _ in range(num_blocks)))
+        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

    def forward(self, x):
-        return self.convs(x)
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = [x1, x2]
+
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.m)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
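+
+# Channel bookkeeping: input_proj yields out_dim channels, chunked into two
+# halves of inter_dim = out_dim // 2; each of the num_blocks bottlenecks
+# appends another inter_dim tensor, so the concat seen by output_proj has
+# (2 + num_blocks) * inter_dim channels.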


# ----------------- Transformer modules -----------------
+## Transformer layer
+class TransformerLayer(nn.Module):
+    def __init__(self,
+                 d_model   :int   = 256,
+                 num_heads :int   = 8,
+                 mlp_ratio :float = 4.0,
+                 dropout   :float = 0.1,
+                 act_type  :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.act_type = act_type
+        # ----------- Network parameters -----------
+        # Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+        # Feedforward Network
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self, src, pos=None):
+        """
+        Input:
+            src: [torch.Tensor] -> [B, N, C]
+            pos: [torch.Tensor] -> [B, N, C]
+        Output:
+            src: [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos)
+
+        # -------------- MHSA --------------
+        # nn.MultiheadAttention returns (attn_output, attn_weights); keep the output only.
+        src2 = self.self_attn(q, k, value=src)[0]
+        src = src + self.dropout(src2)
+        src = self.norm(src)
+
+        # -------------- FFN --------------
+        src = self.ffn(src)
+
+        return src
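+
+# A minimal usage sketch (illustrative, not from the original code):
+#   layer = TransformerLayer(d_model=256, num_heads=8)
+#   src = torch.randn(2, 100, 256)               # [B, N, C], batch_first=True
+#   out = layer(src, pos=torch.zeros_like(src))  # -> [2, 100, 256]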