| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import List
- try:
- from .basic import get_clones, BasicConv, RTCBlock, TransformerLayer
- except:
- from basic import get_clones, BasicConv, RTCBlock, TransformerLayer
# Build PaFPN
def build_fpn(cfg, in_dims, out_dim):
    """Construct the neck (feature pyramid) selected by ``cfg['fpn']``.

    Args:
        cfg: config dict. ``cfg['fpn']`` selects the variant; for
            ``'hybrid_encoder'`` the keys ``width``, ``depth``, ``fpn_act``,
            ``fpn_norm``, ``fpn_depthwise``, ``en_num_heads``,
            ``en_num_layers``, ``en_mlp_ratio``, ``en_dropout``,
            ``pe_temperature`` and ``en_act`` are read as well.
        in_dims: input channel counts, forwarded unchanged to the neck.
        out_dim: output channel count, forwarded unchanged to the neck.

    Returns:
        A ``HybridEncoder`` instance.

    Raises:
        NotImplementedError: if ``cfg['fpn']`` names an unknown variant.
        KeyError: if a required config key is missing.
    """
    fpn_name = cfg['fpn']
    if fpn_name != 'hybrid_encoder':
        raise NotImplementedError("Unknown PaFPN: <{}>".format(fpn_name))

    # Map config keys onto HybridEncoder keyword arguments (read in this order).
    encoder_kwargs = dict(
        width=cfg['width'],
        depth=cfg['depth'],
        act_type=cfg['fpn_act'],
        norm_type=cfg['fpn_norm'],
        depthwise=cfg['fpn_depthwise'],
        num_heads=cfg['en_num_heads'],
        num_layers=cfg['en_num_layers'],
        mlp_ratio=cfg['en_mlp_ratio'],
        dropout=cfg['en_dropout'],
        pe_temperature=cfg['pe_temperature'],
        en_act_type=cfg['en_act'],
    )
    return HybridEncoder(in_dims=in_dims, out_dim=out_dim, **encoder_kwargs)
# ----------------- Feature Pyramid Network -----------------
## Real-time Convolutional PaFPN
class HybridEncoder(nn.Module):
    """Hybrid encoder: transformer on the coarsest level + conv FPN/PAN fusion.

    Takes three backbone maps ``(c3, c4, c5)``, projects each to
    ``round(out_dim * width)`` channels, runs ``num_layers`` cloned
    ``TransformerLayer`` blocks over the flattened projected c5 map with a
    fixed 2D sin-cos positional embedding, then fuses the levels with a
    top-down FPN followed by a bottom-up PAN.

    Args:
        in_dims: channel counts of ``(c3, c4, c5)``.
        out_dim: base number of output channels, scaled by ``width``.
        width: channel multiplier applied to ``out_dim``.
        depth: multiplier for the RTCBlock depth (``round(3 * depth)`` blocks).
        act_type, norm_type: forwarded to every BasicConv / RTCBlock.
        depthwise: forwarded to every RTCBlock.
        num_heads, mlp_ratio, dropout, en_act_type: forwarded positionally to
            ``TransformerLayer``.
        num_layers: number of cloned transformer layers.
        pe_temperature: temperature of the sin-cos positional embedding.
    """

    def __init__(self,
                 in_dims        :List  = [256, 512, 512],
                 out_dim        :int   = 256,
                 width          :float = 1.0,
                 depth          :float = 1.0,
                 act_type       :str   = 'silu',
                 norm_type      :str   = 'BN',
                 depthwise      :bool  = False,
                 # Transformer's parameters
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 mlp_ratio      :float = 4.0,
                 dropout        :float = 0.1,
                 pe_temperature :float = 10000.,
                 en_act_type    :str   = 'gelu'
                 ) -> None:
        super(HybridEncoder, self).__init__()
        print('==============================')
        print('FPN: {}'.format("RTC-PaFPN"))
        # ---------------- Basic parameters ----------------
        self.in_dims = in_dims
        self.out_dim = round(out_dim * width)
        self.width = width
        self.depth = depth
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.pe_temperature = pe_temperature
        # Cached positional embedding plus the
        # (w, h, embed_dim, temperature, device, dtype) key it was built for.
        self.pos_embed = None
        self._pos_embed_key = None
        c3, c4, c5 = in_dims

        # NOTE: attribute names below are state_dict keys (including the
        # "dowmsample" spelling) and the creation order fixes the init RNG
        # order; keep both stable for checkpoint compatibility.

        # ---------------- Input projs ----------------
        # kernel_size=1 projections of every level to self.out_dim channels.
        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

        # ---------------- Downsample ----------------
        # 3x3 stride-2 convs used on the bottom-up path.
        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)

        # ---------------- Transformer Encoder ----------------
        self.transformer_encoder = get_clones(
            TransformerLayer(self.out_dim, num_heads, mlp_ratio, dropout, en_act_type), num_layers)

        def make_fusion_block():
            # Fuses two concatenated out_dim-channel maps back to out_dim channels.
            return RTCBlock(in_dim     = self.out_dim * 2,
                            out_dim    = self.out_dim,
                            num_blocks = round(3 * depth),
                            shortcut   = False,
                            act_type   = act_type,
                            norm_type  = norm_type,
                            depthwise  = depthwise,
                            )

        # ---------------- Top down FPN ----------------
        self.top_down_layer_1 = make_fusion_block()   # P5 -> P4
        self.top_down_layer_2 = make_fusion_block()   # P4 -> P3

        # ---------------- Bottom up PAN ----------------
        self.bottom_up_layer_1 = make_fusion_block()  # P3 -> P4
        self.bottom_up_layer_2 = make_fusion_block()  # P4 -> P5

        self.init_weights()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()

    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.,
                                           device=None, dtype=None):
        """Return a fixed 2D sin-cos positional embedding of shape [1, h*w, embed_dim].

        Positions are ordered row-major (row index y outer, column index x
        inner), i.e. the same order as ``x.flatten(2).permute(0, 2, 1)`` of a
        [B, C, h, w] map. Channels are laid out as
        ``[sin(y*omega), cos(y*omega), sin(x*omega), cos(x*omega)]``, each
        ``embed_dim // 4`` wide, with ``omega_k = temperature ** (-k / (embed_dim // 4))``.

        Args:
            w, h: spatial width and height (converted with ``int``).
            embed_dim: channel dimension; must be divisible by 4.
            temperature: frequency base.
            device: device to build on (``None`` -> torch default).
            dtype: dtype of the returned tensor (``None`` -> float32). The
                embedding is always computed in float32 first.

        Returns:
            The embedding tensor. The result is cached on ``self.pos_embed``
            and reused while all arguments stay the same.

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by 4.
        """
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'

        # ----------- Check cached pos_embed -----------
        key = (int(w), int(h), int(embed_dim), float(temperature), device, dtype)
        if self.pos_embed is not None and getattr(self, '_pos_embed_key', None) == key:
            return self.pos_embed

        # ----------- Generate grid coords -----------
        rows = torch.arange(int(h), dtype=torch.float32, device=device)
        cols = torch.arange(int(w), dtype=torch.float32, device=device)
        # [H, W] grids, flattened row-major so they line up with the feature
        # map. On square maps this is exactly the previous (w, h)-meshgrid
        # layout; on non-square maps the previous layout scrambled positions.
        grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')

        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32, device=device) / pos_dim
        omega = 1. / (temperature ** omega)
        out_y = grid_y.flatten()[..., None] @ omega[None]  # [N, pos_dim]
        out_x = grid_x.flatten()[..., None] @ omega[None]  # [N, pos_dim]

        # [1, N, embed_dim]
        pos_embed = torch.cat(
            [out_y.sin(), out_y.cos(), out_x.sin(), out_x.cos()], dim=1)[None, :, :]
        if dtype is not None:
            pos_embed = pos_embed.to(dtype)

        self.pos_embed = pos_embed
        self._pos_embed_key = key
        return pos_embed

    def forward(self, features):
        """Fuse three backbone levels into a refined pyramid.

        Args:
            features: sequence ``(c3, c4, c5)`` of 4D maps laid out as
                [B, C, H, W] (the flatten/reshape below relies on this);
                c5 is the level upsampled into c4, and c4 into c3.

        Returns:
            ``[p3, p4, p5]``, each produced by an RTCBlock configured with
            ``out_dim=self.out_dim``.
        """
        c3, c4, c5 = features
        # -------- Input projs --------
        p5 = self.reduce_layer_1(c5)
        p4 = self.reduce_layer_2(c4)
        p3 = self.reduce_layer_3(c3)

        # -------- Transformer encoder --------
        encoder_layers = list(self.transformer_encoder) if self.transformer_encoder is not None else []
        if encoder_layers:
            channels, fmp_h, fmp_w = p5.shape[1:]
            # Build the embedding on p5's device/dtype so the addition inside
            # the encoder works on GPU and under mixed precision.
            pos_embed = self.build_2d_sincos_position_embedding(
                fmp_w, fmp_h, channels, self.pe_temperature,
                device=p5.device, dtype=p5.dtype)
            # [B, C, H, W] -> [B, N, C], N=HxW; flatten once for all layers.
            memory = p5.flatten(2).permute(0, 2, 1)
            for encoder in encoder_layers:
                memory = encoder(memory, pos_embed=pos_embed)
            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
            p5 = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])

        # -------- Top down FPN --------
        # Upsample to the finer level's exact size (equivalent to x2 when the
        # sizes are exact multiples, and safe when they are not).
        p5_up = F.interpolate(p5, size=p4.shape[-2:])
        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
        p4_up = F.interpolate(p4, size=p3.shape[-2:])
        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))

        # -------- Bottom up PAN --------
        p3_ds = self.dowmsample_layer_1(p3)
        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
        p4_ds = self.dowmsample_layer_2(p4)
        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))

        out_feats = [p3, p4, p5]
        return out_feats
if __name__ == '__main__':
    # Smoke test / benchmark: build the neck from a config, time one forward
    # pass, print output shapes, then report FLOPs and parameters via thop.
    import time
    from thop import profile

    cfg = {
        'width': 1.0,
        'depth': 1.0,
        'fpn': 'hybrid_encoder',
        'fpn_act': 'silu',
        'fpn_norm': 'BN',
        'fpn_depthwise': False,
        'en_num_heads': 8,
        'en_num_layers': 1,
        'en_mlp_ratio': 4.0,
        'en_dropout': 0.1,
        'pe_temperature': 10000.,
        'en_act': 'gelu',
    }
    fpn_dims = [256, 512, 1024]
    out_dim = 256
    # Three dummy levels at 80x80, 40x40 and 20x20 (each half the previous).
    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
                     torch.randn(1, fpn_dims[1], 40, 40),
                     torch.randn(1, fpn_dims[2], 20, 20)]
    model = build_fpn(cfg, fpn_dims, out_dim)

    # Time the inference path: eval() switches layers such as dropout to
    # inference behaviour, and no_grad() skips building the autograd graph.
    # perf_counter is monotonic and high-resolution, unlike time.time().
    model.eval()
    with torch.no_grad():
        t0 = time.perf_counter()
        outputs = model(pyramid_feats)
        t1 = time.perf_counter()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    print('==============================')
    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
    print('==============================')
    # The first value from thop's profile is treated as multiply-accumulates
    # here, hence the x2 to report FLOPs.
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
|