vitdet_encoder.py

import torch
import torch.nn as nn

try:
    from .basic_modules.basic import BasicConv, UpSampleWrapper
    from .basic_modules.backbone import build_backbone
except ImportError:
    from basic_modules.basic import BasicConv, UpSampleWrapper
    from basic_modules.backbone import build_backbone
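
# NOTE: `BasicConv`, `UpSampleWrapper`, and `build_backbone` are defined in
# `basic_modules`. Their interfaces, as inferred from the call sites below
# (an assumption based on this file, not the actual implementations):
#   BasicConv(in_dim, out_dim, kernel_size, padding=0, stride=1,
#             act_type=None, norm_type='BN')   -> conv + norm (+ activation)
#   UpSampleWrapper(dim, scale)                -> spatial upsampling by `scale`
#   build_backbone(cfg, pretrained)            -> (backbone, out_channels)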


# ----------------- Image Encoder -----------------
def build_image_encoder(cfg):
    return ImageEncoder(cfg)


class ImageEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # ---------------- Basic settings ----------------
        ## Basic parameters
        self.cfg = cfg
        ## Network parameters
        self.stride = 16
        self.fpn_dims = [cfg['hidden_dim']] * 3
        self.hidden_dim = cfg['hidden_dim']

        # ---------------- Network settings ----------------
        ## Backbone Network
        self.backbone, backbone_dim = build_backbone(cfg, cfg['pretrained'])
        ## Input projection
        self.input_proj = BasicConv(backbone_dim, cfg['hidden_dim'],
                                    kernel_size=1,
                                    act_type=None, norm_type='BN')
        ## Upsample layer
        self.upsample = UpSampleWrapper(cfg['hidden_dim'], 2.0)
        ## Downsample layer
        self.downsample = BasicConv(cfg['hidden_dim'], cfg['hidden_dim'],
                                    kernel_size=3, padding=1, stride=2,
                                    act_type=None, norm_type='BN')
        ## Output projection
        ## Use a comprehension so each pyramid level gets its own conv:
        ## `[BasicConv(...)] * 3` would register the same module (and weights)
        ## three times, silently sharing parameters across levels.
        self.output_projs = nn.ModuleList([BasicConv(cfg['hidden_dim'], cfg['hidden_dim'],
                                                     kernel_size=3, padding=1,
                                                     act_type='silu', norm_type='BN')
                                           for _ in range(3)])

    def forward(self, x):
        # Backbone
        feat = self.backbone(x)
        # Input proj
        feat = self.input_proj(feat)
        # FPN
        feat_up = self.upsample(feat)
        feat_ds = self.downsample(feat)
        # Multi level features: [P3, P4, P5]
        pyramid_feats = [self.output_projs[0](feat_up),
                         self.output_projs[1](feat),
                         self.output_projs[2](feat_ds)]

        return pyramid_feats
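
# A minimal usage sketch (assumes the stride-16 backbone and the cfg from the
# test block below; shapes are for a 640x640 input with hidden_dim=256):
#   encoder = build_image_encoder(cfg)
#   p3, p4, p5 = encoder(torch.rand(2, 3, 640, 640))
#   # p3: [2, 256, 80, 80]  (stride 8,  upsampled)
#   # p4: [2, 256, 40, 40]  (stride 16, backbone output)
#   # p5: [2, 256, 20, 20]  (stride 32, downsampled)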


if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'width': 1.0,
        'depth': 1.0,
        'out_stride': 16,
        'hidden_dim': 256,
        # Image Encoder - Backbone
        'backbone': 'resnet50',
        'backbone_norm': 'FrozeBN',
        'pretrained': True,
        'freeze_at': 0,
        'freeze_stem_only': False,
    }
    x = torch.rand(2, 3, 640, 640)
    model = build_image_encoder(cfg)
    model.train()

    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    # `outputs` is a list of three pyramid features, not a single tensor.
    for level, out in enumerate(outputs):
        print('Level-{} shape: {}'.format(level, out.shape))
    print('==============================')

    model.eval()
    x = torch.rand(1, 3, 640, 640)
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    # thop counts multiply-accumulates; the x2 converts MACs to FLOPs.
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))