|
|
@@ -0,0 +1,229 @@
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.nn.functional as F
|
|
|
+from typing import List
|
|
|
+
|
|
|
+try:
|
|
|
+ from .basic import get_clones, BasicConv, RTCBlock, TransformerLayer
|
|
|
+except:
|
|
|
+ from basic import get_clones, BasicConv, RTCBlock, TransformerLayer
|
|
|
+
|
|
|
+
|
|
|
+# Build PaFPN
|
|
|
+def build_fpn(cfg, in_dims, out_dim):
|
|
|
+ if cfg['fpn'] == 'hybrid_encoder':
|
|
|
+ return HybridEncoder(in_dims = in_dims,
|
|
|
+ out_dim = out_dim,
|
|
|
+ width = cfg['width'],
|
|
|
+ depth = cfg['depth'],
|
|
|
+ act_type = cfg['fpn_act'],
|
|
|
+ norm_type = cfg['fpn_norm'],
|
|
|
+ depthwise = cfg['fpn_depthwise'],
|
|
|
+ num_heads = cfg['en_num_heads'],
|
|
|
+ num_layers = cfg['en_num_layers'],
|
|
|
+ mlp_ratio = cfg['en_mlp_ratio'],
|
|
|
+ dropout = cfg['en_dropout'],
|
|
|
+ pe_temperature = cfg['pe_temperature'],
|
|
|
+ en_act_type = cfg['en_act'],
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise NotImplementedError("Unknown PaFPN: <{}>".format(cfg['fpn']))
|
|
|
+
|
|
|
+
|
|
|
+# ----------------- Feature Pyramid Network -----------------
|
|
|
+## Real-time Convolutional PaFPN
|
|
|
+class HybridEncoder(nn.Module):
|
|
|
+ def __init__(self,
|
|
|
+ in_dims :List = [256, 512, 512],
|
|
|
+ out_dim :int = 256,
|
|
|
+ width :float = 1.0,
|
|
|
+ depth :float = 1.0,
|
|
|
+ act_type :str = 'silu',
|
|
|
+ norm_type :str = 'BN',
|
|
|
+ depthwise :bool = False,
|
|
|
+ # Transformer's parameters
|
|
|
+ num_heads :int = 8,
|
|
|
+ num_layers :int = 1,
|
|
|
+ mlp_ratio :float = 4.0,
|
|
|
+ dropout :float = 0.1,
|
|
|
+ pe_temperature :float = 10000.,
|
|
|
+ en_act_type :str = 'gelu'
|
|
|
+ ) -> None:
|
|
|
+ super(HybridEncoder, self).__init__()
|
|
|
+ print('==============================')
|
|
|
+ print('FPN: {}'.format("RTC-PaFPN"))
|
|
|
+ # ---------------- Basic parameters ----------------
|
|
|
+ self.in_dims = in_dims
|
|
|
+ self.out_dim = round(out_dim * width)
|
|
|
+ self.width = width
|
|
|
+ self.depth = depth
|
|
|
+ self.num_heads = num_heads
|
|
|
+ self.num_layers = num_layers
|
|
|
+ self.mlp_ratio = mlp_ratio
|
|
|
+ self.pe_temperature = pe_temperature
|
|
|
+ self.pos_embed = None
|
|
|
+ c3, c4, c5 = in_dims
|
|
|
+
|
|
|
+ # ---------------- Input projs ----------------
|
|
|
+ self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
|
|
|
+ self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
|
|
|
+ self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
|
|
|
+
|
|
|
+ # ---------------- Downsample ----------------
|
|
|
+ self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
|
|
|
+ self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
|
|
|
+
|
|
|
+ # ---------------- Transformer Encoder ----------------
|
|
|
+ self.transformer_encoder = get_clones(
|
|
|
+ TransformerLayer(self.out_dim, num_heads, mlp_ratio, dropout, en_act_type), num_layers)
|
|
|
+
|
|
|
+ # ---------------- Top dwon FPN ----------------
|
|
|
+ ## P5 -> P4
|
|
|
+ self.top_down_layer_1 = RTCBlock(in_dim = self.out_dim * 2,
|
|
|
+ out_dim = self.out_dim,
|
|
|
+ num_blocks = round(3*depth),
|
|
|
+ shortcut = False,
|
|
|
+ act_type = act_type,
|
|
|
+ norm_type = norm_type,
|
|
|
+ depthwise = depthwise,
|
|
|
+ )
|
|
|
+ ## P4 -> P3
|
|
|
+ self.top_down_layer_2 = RTCBlock(in_dim = self.out_dim * 2,
|
|
|
+ out_dim = self.out_dim,
|
|
|
+ num_blocks = round(3*depth),
|
|
|
+ shortcut = False,
|
|
|
+ act_type = act_type,
|
|
|
+ norm_type = norm_type,
|
|
|
+ depthwise = depthwise,
|
|
|
+ )
|
|
|
+
|
|
|
+ # ---------------- Bottom up PAN----------------
|
|
|
+ ## P3 -> P4
|
|
|
+ self.bottom_up_layer_1 = RTCBlock(in_dim = self.out_dim * 2,
|
|
|
+ out_dim = self.out_dim,
|
|
|
+ num_blocks = round(3*depth),
|
|
|
+ shortcut = False,
|
|
|
+ act_type = act_type,
|
|
|
+ norm_type = norm_type,
|
|
|
+ depthwise = depthwise,
|
|
|
+ )
|
|
|
+ ## P4 -> P5
|
|
|
+ self.bottom_up_layer_2 = RTCBlock(in_dim = self.out_dim * 2,
|
|
|
+ out_dim = self.out_dim,
|
|
|
+ num_blocks = round(3*depth),
|
|
|
+ shortcut = False,
|
|
|
+ act_type = act_type,
|
|
|
+ norm_type = norm_type,
|
|
|
+ depthwise = depthwise,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.init_weights()
|
|
|
+
|
|
|
+ def init_weights(self):
|
|
|
+ """Initialize the parameters."""
|
|
|
+ for m in self.modules():
|
|
|
+ if isinstance(m, torch.nn.Conv2d):
|
|
|
+ # In order to be consistent with the source code,
|
|
|
+ # reset the Conv2d initialization parameters
|
|
|
+ m.reset_parameters()
|
|
|
+
|
|
|
+ def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
|
|
|
+ assert embed_dim % 4 == 0, \
|
|
|
+ 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
|
|
+
|
|
|
+ # ----------- Check cahed pos_embed -----------
|
|
|
+ if self.pos_embed is not None and \
|
|
|
+ self.pos_embed.shape[2:] == [h, w]:
|
|
|
+ return self.pos_embed
|
|
|
+
|
|
|
+ # ----------- Generate grid coords -----------
|
|
|
+ grid_w = torch.arange(int(w), dtype=torch.float32)
|
|
|
+ grid_h = torch.arange(int(h), dtype=torch.float32)
|
|
|
+ grid_w, grid_h = torch.meshgrid([grid_w, grid_h]) # shape: [H, W]
|
|
|
+
|
|
|
+ pos_dim = embed_dim // 4
|
|
|
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
|
|
+ omega = 1. / (temperature**omega)
|
|
|
+
|
|
|
+ out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
|
|
|
+ out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
|
|
|
+
|
|
|
+ # shape: [1, N, C]
|
|
|
+ pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], axis=1)[None, :, :]
|
|
|
+ self.pos_embed = pos_embed
|
|
|
+
|
|
|
+ return pos_embed
|
|
|
+
|
|
|
+ def forward(self, features):
|
|
|
+ c3, c4, c5 = features
|
|
|
+
|
|
|
+ # -------- Input projs --------
|
|
|
+ p5 = self.reduce_layer_1(c5)
|
|
|
+ p4 = self.reduce_layer_2(c4)
|
|
|
+ p3 = self.reduce_layer_3(c3)
|
|
|
+
|
|
|
+ # -------- Transformer encoder --------
|
|
|
+ if self.transformer_encoder is not None:
|
|
|
+ for encoder in self.transformer_encoder:
|
|
|
+ channels, fmp_h, fmp_w = p5.shape[1:]
|
|
|
+ # [B, C, H, W] -> [B, N, C], N=HxW
|
|
|
+ src_flatten = p5.flatten(2).permute(0, 2, 1)
|
|
|
+ pos_embed = self.build_2d_sincos_position_embedding(
|
|
|
+ fmp_w, fmp_h, channels, self.pe_temperature)
|
|
|
+ memory = encoder(src_flatten, pos_embed=pos_embed)
|
|
|
+ # [B, N, C] -> [B, C, N] -> [B, C, H, W]
|
|
|
+ p5 = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
|
|
|
+
|
|
|
+ # -------- Top down FPN --------
|
|
|
+ p5_up = F.interpolate(p5, scale_factor=2.0)
|
|
|
+ p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
|
|
|
+
|
|
|
+ p4_up = F.interpolate(p4, scale_factor=2.0)
|
|
|
+ p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
|
|
|
+
|
|
|
+ # -------- Bottom up PAN --------
|
|
|
+ p3_ds = self.dowmsample_layer_1(p3)
|
|
|
+ p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
|
|
|
+
|
|
|
+ p4_ds = self.dowmsample_layer_2(p4)
|
|
|
+ p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
|
|
|
+
|
|
|
+ out_feats = [p3, p4, p5]
|
|
|
+
|
|
|
+ return out_feats
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ import time
|
|
|
+ from thop import profile
|
|
|
+ cfg = {
|
|
|
+ 'width': 1.0,
|
|
|
+ 'depth': 1.0,
|
|
|
+ 'fpn': 'hybrid_encoder',
|
|
|
+ 'fpn_act': 'silu',
|
|
|
+ 'fpn_norm': 'BN',
|
|
|
+ 'fpn_depthwise': False,
|
|
|
+ 'en_num_heads': 8,
|
|
|
+ 'en_num_layers': 1,
|
|
|
+ 'en_mlp_ratio': 4.0,
|
|
|
+ 'en_dropout': 0.1,
|
|
|
+ 'pe_temperature': 10000.,
|
|
|
+ 'en_act': 'gelu',
|
|
|
+ }
|
|
|
+ fpn_dims = [256, 512, 1024]
|
|
|
+ out_dim = 256
|
|
|
+ pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
|
|
|
+ model = build_fpn(cfg, fpn_dims, out_dim)
|
|
|
+
|
|
|
+ t0 = time.time()
|
|
|
+ outputs = model(pyramid_feats)
|
|
|
+ t1 = time.time()
|
|
|
+ print('Time: ', t1 - t0)
|
|
|
+ for out in outputs:
|
|
|
+ print(out.shape)
|
|
|
+
|
|
|
+ print('==============================')
|
|
|
+ flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
|
|
|
+ print('==============================')
|
|
|
+ print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
|
|
|
+ print('Params : {:.2f} M'.format(params / 1e6))
|