# rtcdet_backbone.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .rtcdet_basic import BasicConv, RTCBlock
  5. except:
  6. from rtcdet_basic import BasicConv, RTCBlock
# ---------------------------- Basic functions ----------------------------
## YOLOv8's backbone
  9. class RTCBackbone(nn.Module):
  10. def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
  11. super(RTCBackbone, self).__init__()
  12. self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(1024 * width)]
  13. # P1/2
  14. self.layer_1 = BasicConv(3, self.feat_dims[0], kernel_size=6, padding=2, stride=2, act_type=act_type, norm_type=norm_type)
  15. # P2/4
  16. self.layer_2 = nn.Sequential(
  17. BasicConv(self.feat_dims[0], self.feat_dims[1],
  18. kernel_size=3, padding=1, stride=2,
  19. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  20. RTCBlock(in_dim = self.feat_dims[1],
  21. out_dim = self.feat_dims[1],
  22. num_blocks = round(3*depth),
  23. shortcut = True,
  24. act_type = act_type,
  25. norm_type = norm_type,
  26. depthwise = depthwise)
  27. )
  28. # P3/8
  29. self.layer_3 = nn.Sequential(
  30. BasicConv(self.feat_dims[1], self.feat_dims[2],
  31. kernel_size=3, padding=1, stride=2,
  32. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  33. RTCBlock(in_dim = self.feat_dims[2],
  34. out_dim = self.feat_dims[2],
  35. num_blocks = round(9*depth),
  36. shortcut = True,
  37. act_type = act_type,
  38. norm_type = norm_type,
  39. depthwise = depthwise)
  40. )
  41. # P4/16
  42. self.layer_4 = nn.Sequential(
  43. BasicConv(self.feat_dims[2], self.feat_dims[3],
  44. kernel_size=3, padding=1, stride=2,
  45. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  46. RTCBlock(in_dim = self.feat_dims[3],
  47. out_dim = self.feat_dims[3],
  48. num_blocks = round(9*depth),
  49. shortcut = True,
  50. act_type = act_type,
  51. norm_type = norm_type,
  52. depthwise = depthwise)
  53. )
  54. # P5/32
  55. self.layer_5 = nn.Sequential(
  56. BasicConv(self.feat_dims[3], self.feat_dims[4],
  57. kernel_size=3, padding=1, stride=2,
  58. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  59. RTCBlock(in_dim = self.feat_dims[4],
  60. out_dim = self.feat_dims[4],
  61. num_blocks = round(3*depth),
  62. shortcut = True,
  63. act_type = act_type,
  64. norm_type = norm_type,
  65. depthwise = depthwise)
  66. )
  67. self.init_weights()
  68. def init_weights(self):
  69. """Initialize the parameters."""
  70. for m in self.modules():
  71. if isinstance(m, torch.nn.Conv2d):
  72. # In order to be consistent with the source code,
  73. # reset the Conv2d initialization parameters
  74. m.reset_parameters()
  75. def forward(self, x):
  76. c1 = self.layer_1(x)
  77. c2 = self.layer_2(c1)
  78. c3 = self.layer_3(c2)
  79. c4 = self.layer_4(c3)
  80. c5 = self.layer_5(c4)
  81. outputs = [c3, c4, c5]
  82. return outputs
# ---------------------------- Functions ----------------------------
## build Yolov8's Backbone
  85. def build_backbone(cfg):
  86. # model
  87. backbone = RTCBackbone(width=cfg['width'],
  88. depth=cfg['depth'],
  89. act_type=cfg['bk_act'],
  90. norm_type=cfg['bk_norm'],
  91. depthwise=cfg['bk_depthwise']
  92. )
  93. feat_dims = backbone.feat_dims[-3:]
  94. return backbone, feat_dims
  95. if __name__ == '__main__':
  96. import time
  97. from thop import profile
  98. cfg = {
  99. 'bk_act': 'silu',
  100. 'bk_norm': 'BN',
  101. 'bk_depthwise': False,
  102. 'width': 1.0,
  103. 'depth': 1.0,
  104. 'ratio': 1.0,
  105. }
  106. model, feats = build_backbone(cfg)
  107. x = torch.randn(1, 3, 640, 640)
  108. t0 = time.time()
  109. outputs = model(x)
  110. t1 = time.time()
  111. print('Time: ', t1 - t0)
  112. for out in outputs:
  113. print(out.shape)
  114. x = torch.randn(1, 3, 640, 640)
  115. print('==============================')
  116. flops, params = profile(model, inputs=(x, ), verbose=False)
  117. print('==============================')
  118. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  119. print('Params : {:.2f} M'.format(params / 1e6))