# rtpdetr_encoder.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. try:
  5. from .basic_modules.basic import BasicConv, UpSampleWrapper
  6. from .basic_modules.backbone import build_backbone
  7. except:
  8. from basic_modules.basic import BasicConv, UpSampleWrapper
  9. from basic_modules.backbone import build_backbone
  10. # ----------------- Image Encoder -----------------
  11. def build_image_encoder(cfg):
  12. return ImageEncoder(cfg)
  13. class ImageEncoder(nn.Module):
  14. def __init__(self, cfg):
  15. super().__init__()
  16. # ---------------- Basic settings ----------------
  17. ## Basic parameters
  18. self.cfg = cfg
  19. ## Network parameters
  20. self.stride = cfg['out_stride']
  21. self.upsample_factor = 32 // self.stride
  22. self.hidden_dim = cfg['hidden_dim']
  23. # ---------------- Network settings ----------------
  24. ## Backbone Network
  25. self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained']&self.training)
  26. ## Upsample layer
  27. self.upsample = UpSampleWrapper(fpn_feat_dims[-1], self.upsample_factor)
  28. ## Input projection
  29. self.input_proj = BasicConv(self.upsample.out_dim, self.hidden_dim, kernel_size=1, act_type=None, norm_type='BN')
  30. def forward(self, x):
  31. pyramid_feats = self.backbone(x)
  32. feat = self.upsample(pyramid_feats[-1])
  33. feat = self.input_proj(feat)
  34. return feat
  35. if __name__ == '__main__':
  36. import time
  37. from thop import profile
  38. cfg = {
  39. 'width': 1.0,
  40. 'depth': 1.0,
  41. 'out_stride': 16,
  42. # Image Encoder - Backbone
  43. 'backbone': 'resnet50',
  44. 'backbone_norm': 'BN',
  45. 'res5_dilation': False,
  46. 'pretrained': True,
  47. 'pretrained_weight': 'imagenet1k_v1',
  48. 'hidden_dim': 256,
  49. }
  50. x = torch.rand(2, 3, 640, 640)
  51. model = build_image_encoder(cfg)
  52. model.train()
  53. t0 = time.time()
  54. outputs = model(x)
  55. t1 = time.time()
  56. print('Time: ', t1 - t0)
  57. print(outputs.shape)
  58. print('==============================')
  59. model.eval()
  60. x = torch.rand(1, 3, 640, 640)
  61. flops, params = profile(model, inputs=(x, ), verbose=False)
  62. print('==============================')
  63. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  64. print('Params : {:.2f} M'.format(params / 1e6))