# rtdetr_encoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

# Support both package-relative imports and direct script execution.
try:
    from .basic_modules.backbone import build_backbone
    from .basic_modules.fpn import build_fpn
except ImportError:
    from basic_modules.backbone import build_backbone
    from basic_modules.fpn import build_fpn


# ----------------- Image Encoder -----------------
def build_image_encoder(cfg):
    return ImageEncoder(cfg)


class ImageEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # ---------------- Basic settings ----------------
        ## Basic parameters
        self.cfg = cfg
        ## Network parameters
        self.strides = cfg['out_stride']
        self.hidden_dim = cfg['hidden_dim']
        self.num_levels = len(self.strides)

        # ---------------- Network settings ----------------
        ## Backbone Network
        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained'] and self.training)
        ## Feature Pyramid Network
        self.fpn = build_fpn(cfg, fpn_feat_dims, self.hidden_dim)
        self.fpn_dims = self.fpn.out_dims

    def forward(self, x):
        # Backbone extracts multi-scale features; the FPN fuses them into the output pyramid.
        pyramid_feats = self.backbone(x)
        pyramid_feats = self.fpn(pyramid_feats)

        return pyramid_feats
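

# Usage sketch (illustrative; assumes the hybrid_encoder FPN emits one
# `hidden_dim`-channel feature map per entry in cfg['out_stride']):
#
#   encoder = build_image_encoder(cfg)
#   p3, p4, p5 = encoder(images)   # strides 8, 16, 32 with the config below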


if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'width': 1.0,
        'depth': 1.0,
        'out_stride': [8, 16, 32],
        # Image Encoder - Backbone
        'backbone': 'resnet18',
        'backbone_norm': 'BN',
        'res5_dilation': False,
        'pretrained': True,
        'pretrained_weight': 'imagenet1k_v1',
        # Image Encoder - FPN
        'fpn': 'hybrid_encoder',
        'fpn_act': 'silu',
        'fpn_norm': 'BN',
        'fpn_depthwise': False,
        'hidden_dim': 256,
        'en_num_heads': 8,
        'en_num_layers': 1,
        'en_mlp_ratio': 4.0,
        'en_dropout': 0.1,
        'pe_temperature': 10000.,
        'en_act': 'gelu',
    }
    x = torch.rand(2, 3, 640, 640)
    model = build_image_encoder(cfg)
    model.train()

    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    print('==============================')
    model.eval()
    x = torch.rand(1, 3, 640, 640)
    # thop reports multiply-accumulate operations (MACs), so double the count for GFLOPs.
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
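
    # Optional sanity check (a minimal sketch; assumes the FPN returns the
    # pyramid levels in the same order as cfg['out_stride'], each with
    # cfg['hidden_dim'] channels): compare every level's spatial size against
    # the input resolution divided by its stride.
    with torch.no_grad():
        feats = model(x)
    for stride, feat in zip(cfg['out_stride'], feats):
        expected = (1, cfg['hidden_dim'], 640 // stride, 640 // stride)
        print('stride {:>2d}: {} (expected {})'.format(stride, tuple(feat.shape), expected))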