# rtcdet_backbone.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .rtcdet_basic import BasicConv, RTCBlock
  5. except:
  6. from rtcdet_basic import BasicConv, RTCBlock
# ---------------------------- Basic modules ----------------------------
## RTCDet backbone (YOLOv8-style design)
  9. class RTCBackbone(nn.Module):
  10. def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
  11. super(RTCBackbone, self).__init__()
  12. self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(512 * width * ratio)]
  13. # P1/2
  14. self.layer_1 = BasicConv(3, self.feat_dims[0], kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
  15. # P2/4
  16. self.layer_2 = nn.Sequential(
  17. BasicConv(self.feat_dims[0], self.feat_dims[1],
  18. kernel_size=3, padding=1, stride=2,
  19. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  20. RTCBlock(in_dim = self.feat_dims[1],
  21. out_dim = self.feat_dims[1],
  22. num_blocks = round(3*depth),
  23. shortcut = True,
  24. act_type = act_type,
  25. norm_type = norm_type,
  26. depthwise = depthwise)
  27. )
  28. # P3/8
  29. self.layer_3 = nn.Sequential(
  30. BasicConv(self.feat_dims[1], self.feat_dims[2],
  31. kernel_size=3, padding=1, stride=2,
  32. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  33. RTCBlock(in_dim = self.feat_dims[2],
  34. out_dim = self.feat_dims[2],
  35. num_blocks = round(6*depth),
  36. shortcut = True,
  37. act_type = act_type,
  38. norm_type = norm_type,
  39. depthwise = depthwise)
  40. )
  41. # P4/16
  42. self.layer_4 = nn.Sequential(
  43. BasicConv(self.feat_dims[2], self.feat_dims[3],
  44. kernel_size=3, padding=1, stride=2,
  45. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  46. RTCBlock(in_dim = self.feat_dims[3],
  47. out_dim = self.feat_dims[3],
  48. num_blocks = round(6*depth),
  49. shortcut = True,
  50. act_type = act_type,
  51. norm_type = norm_type,
  52. depthwise = depthwise)
  53. )
  54. # P5/32
  55. self.layer_5 = nn.Sequential(
  56. BasicConv(self.feat_dims[3], self.feat_dims[4],
  57. kernel_size=3, padding=1, stride=2,
  58. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  59. RTCBlock(in_dim = self.feat_dims[4],
  60. out_dim = self.feat_dims[4],
  61. num_blocks = round(3*depth),
  62. shortcut = True,
  63. act_type = act_type,
  64. norm_type = norm_type,
  65. depthwise = depthwise)
  66. )
  67. self.init_weights()
  68. def init_weights(self):
  69. """Initialize the parameters."""
  70. for m in self.modules():
  71. if isinstance(m, torch.nn.Conv2d):
  72. # In order to be consistent with the source code,
  73. # reset the Conv2d initialization parameters
  74. m.reset_parameters()
  75. def forward(self, x):
  76. c1 = self.layer_1(x)
  77. c2 = self.layer_2(c1)
  78. c3 = self.layer_3(c2)
  79. c4 = self.layer_4(c3)
  80. c5 = self.layer_5(c4)
  81. outputs = [c3, c4, c5]
  82. return outputs
# ---------------------------- Functions ----------------------------
## Build the RTCDet backbone (RTCBackbone) from a config dict
  85. def build_backbone(cfg):
  86. # model
  87. backbone = RTCBackbone(width=cfg['width'],
  88. depth=cfg['depth'],
  89. ratio=cfg['ratio'],
  90. act_type=cfg['bk_act'],
  91. norm_type=cfg['bk_norm'],
  92. depthwise=cfg['bk_depthwise']
  93. )
  94. feat_dims = backbone.feat_dims[-3:]
  95. return backbone, feat_dims
  96. if __name__ == '__main__':
  97. import time
  98. from thop import profile
  99. cfg = {
  100. 'bk_act': 'silu',
  101. 'bk_norm': 'BN',
  102. 'bk_depthwise': False,
  103. 'width': 1.0,
  104. 'depth': 1.0,
  105. 'ratio': 1.0,
  106. }
  107. model, feats = build_backbone(cfg)
  108. x = torch.randn(1, 3, 640, 640)
  109. t0 = time.time()
  110. outputs = model(x)
  111. t1 = time.time()
  112. print('Time: ', t1 - t0)
  113. for out in outputs:
  114. print(out.shape)
  115. x = torch.randn(1, 3, 640, 640)
  116. print('==============================')
  117. flops, params = profile(model, inputs=(x, ), verbose=False)
  118. print('==============================')
  119. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  120. print('Params : {:.2f} M'.format(params / 1e6))