import torch
import torch.nn as nn
import torch.nn.functional as F

# Support running both as a package module and as a standalone script.
try:
    from .basic_modules.basic import BasicConv, UpSampleWrapper
    from .basic_modules.backbone import build_backbone
    from .basic_modules.transformer import TransformerEncoder
except ImportError:
    from basic_modules.basic import BasicConv, UpSampleWrapper
    from basic_modules.backbone import build_backbone
    from basic_modules.transformer import TransformerEncoder

# ----------------- Image Encoder -----------------
def build_image_encoder(cfg):
    return ImageEncoder(cfg)

class ImageEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # ---------------- Basic settings ----------------
        ## Basic parameters
        self.cfg = cfg
        ## Network parameters
        self.stride = cfg['out_stride']
        self.upsample_factor = 32 // self.stride  # e.g. out_stride=16 -> x2 upsample
        self.hidden_dim = cfg['hidden_dim']

        # ---------------- Network settings ----------------
        ## Backbone network (load pretrained weights only when training)
        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained'] and self.training)
        ## Input projection: map the backbone's last feature map to the transformer width
        self.input_proj = BasicConv(fpn_feat_dims[-1], cfg['hidden_dim'], kernel_size=1, act_type=None, norm_type='BN')
        # ---------------- Transformer Encoder ----------------
        self.transformer_encoder = TransformerEncoder(d_model    = cfg['hidden_dim'],
                                                      num_heads  = cfg['en_num_heads'],
                                                      num_layers = cfg['en_num_layers'],
                                                      ffn_dim    = cfg['en_ffn_dim'],
                                                      dropout    = cfg['en_dropout'],
                                                      act_type   = cfg['en_act'],
                                                      )
        ## Upsample layer
        self.upsample = UpSampleWrapper(cfg['hidden_dim'], self.upsample_factor)
        ## Output projection
        self.output_proj = BasicConv(cfg['hidden_dim'], cfg['hidden_dim'], kernel_size=3, padding=1, act_type='silu', norm_type='BN')

    def forward(self, x):
        pyramid_feats = self.backbone(x)           # multi-scale features; the last is the coarsest
        feat = self.input_proj(pyramid_feats[-1])  # 1x1 conv -> hidden_dim channels
        feat = self.transformer_encoder(feat)      # global self-attention over the coarsest map
        feat = self.upsample(feat)                 # back up to the configured out_stride
        feat = self.output_proj(feat)

        return feat
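
# Shape note (assuming the standard stride-32 output of a ResNet backbone):
# with out_stride=16, hidden_dim=256, and a 640x640 input, the backbone's last
# feature map is 20x20 (640/32), and the x2 upsample restores stride 16,
# so forward() returns a (B, 256, 40, 40) tensor.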


if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'width': 1.0,
        'depth': 1.0,
        'out_stride': 16,
        # Image Encoder - Backbone
        'backbone': 'resnet50',
        'backbone_norm': 'FrozeBN',
        'pretrained': True,
        'freeze_at': 0,
        'freeze_stem_only': False,
        # Image Encoder - Transformer
        'hidden_dim': 256,
        'en_num_heads': 8,
        'en_num_layers': 1,
        'en_ffn_dim': 1024,
        'en_dropout': 0.0,
        'en_act': 'gelu',
    }

    # Time a forward pass in train mode
    x = torch.rand(2, 3, 640, 640)
    model = build_image_encoder(cfg)
    model.train()

    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    print(outputs.shape)
    print('==============================')

    # Measure FLOPs and parameter count in eval mode
    model.eval()
    x = torch.rand(1, 3, 640, 640)
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))