# rtpdetr_encoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from .basic_modules.basic import BasicConv, UpSampleWrapper
    from .basic_modules.backbone import build_backbone
    from .basic_modules.transformer import TransformerEncoder
except ImportError:
    from basic_modules.basic import BasicConv, UpSampleWrapper
    from basic_modules.backbone import build_backbone
    from basic_modules.transformer import TransformerEncoder

# ----------------- Image Encoder -----------------
def build_image_encoder(cfg):
    return ImageEncoder(cfg)

class ImageEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # ---------------- Basic settings ----------------
        ## Basic parameters
        self.cfg = cfg
        ## Network parameters
        self.stride = cfg['out_stride']
        self.upsample_factor = 32 // self.stride   # backbone's deepest feature has stride 32
        self.hidden_dim = cfg['hidden_dim']

        # ---------------- Network settings ----------------
        ## Backbone network
        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained'] and self.training)
        ## Input projection: map the deepest backbone feature to hidden_dim channels
        self.input_proj = BasicConv(fpn_feat_dims[-1], cfg['hidden_dim'],
                                    kernel_size=1, act_type=None, norm_type='BN')

        # ---------------- Transformer Encoder ----------------
        self.transformer_encoder = TransformerEncoder(d_model    = cfg['hidden_dim'],
                                                      num_heads  = cfg['en_num_heads'],
                                                      num_layers = cfg['en_num_layers'],
                                                      ffn_dim    = cfg['en_ffn_dim'],
                                                      dropout    = cfg['en_dropout'],
                                                      act_type   = cfg['en_act'],
                                                      )
        ## Upsample layer: recover the desired output stride
        self.upsample = UpSampleWrapper(cfg['hidden_dim'], self.upsample_factor)
        ## Output projection
        self.output_proj = BasicConv(cfg['hidden_dim'], cfg['hidden_dim'],
                                     kernel_size=3, padding=1, act_type='silu', norm_type='BN')

    def forward(self, x):
        # Multi-scale backbone features; only the deepest level is used here
        pyramid_feats = self.backbone(x)
        # Project the deepest feature to hidden_dim channels
        feat = self.input_proj(pyramid_feats[-1])
        # Refine it with the transformer encoder
        feat = self.transformer_encoder(feat)
        # Upsample back to the desired output stride, then apply the output projection
        feat = self.upsample(feat)
        feat = self.output_proj(feat)

        return feat
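
# A minimal shape trace of ImageEncoder.forward, assuming a ResNet-50 backbone whose
# deepest feature (C5) has 2048 channels at stride 32, hidden_dim=256, out_stride=16,
# and a 640x640 input (the actual channel count comes from build_backbone's fpn_feat_dims):
#
#   x:               [B, 3, 640, 640]
#   backbone C5:     [B, 2048, 20, 20]   (stride 32)
#   input_proj:      [B, 256, 20, 20]    (1x1 conv to hidden_dim)
#   encoder:         [B, 256, 20, 20]    (spatial size unchanged)
#   upsample (x2):   [B, 256, 40, 40]    (upsample_factor = 32 // 16)
#   output_proj:     [B, 256, 40, 40]    (3x3 conv, BN + SiLU)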


if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'width': 1.0,
        'depth': 1.0,
        'out_stride': 16,
        # Image Encoder - Backbone
        'backbone': 'resnet50',
        'backbone_norm': 'FrozeBN',
        'pretrained': True,
        'freeze_at': 0,
        'freeze_stem_only': False,
        'hidden_dim': 256,
        'en_num_heads': 8,
        'en_num_layers': 1,
        'en_ffn_dim': 1024,
        'en_dropout': 0.0,
        'en_act': 'gelu',
    }
    x = torch.rand(2, 3, 640, 640)
    model = build_image_encoder(cfg)
    model.train()

    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    print(outputs.shape)

    print('==============================')
    model.eval()
    x = torch.rand(1, 3, 640, 640)
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))   # thop reports MACs; x2 converts to FLOPs
    print('Params : {:.2f} M'.format(params / 1e6))