yolov8_backbone.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov8_basic import BasicConv, ELANLayer
  5. except:
  6. from yolov8_basic import BasicConv, ELANLayer
  7. # ---------------------------- Basic functions ----------------------------
  class Yolov8Backbone(nn.Module):
      """YOLOv8-style backbone: a stem conv followed by four downsampling stages.

      Every stage halves the spatial resolution (all convs below use stride 2),
      giving P1/2 ... P5/32 feature levels. ``forward`` returns the last three
      (C3, C4, C5).

      Args:
          cfg: config object; this class reads ``scale``, ``width``, ``depth``,
              ``ratio``, ``bk_act``, ``bk_norm`` and ``bk_depthwise``.
      """

      def __init__(self, cfg):
          super().__init__()
          # ------------------ Basic setting ------------------
          self.model_scale = cfg.scale
          # Output channels of layer_1 .. layer_5; the last stage is additionally
          # scaled by cfg.ratio.
          self.feat_dims = [round(64 * cfg.width),
                            round(128 * cfg.width),
                            round(256 * cfg.width),
                            round(512 * cfg.width),
                            round(512 * cfg.width * cfg.ratio)]

          # ------------------ Network setting ------------------
          ## P1/2: stem conv only (6x6 kernel, stride 2), no ELAN stage.
          self.layer_1 = BasicConv(3, self.feat_dims[0],
                                   kernel_size=6, padding=2, stride=2,
                                   act_type=cfg.bk_act, norm_type=cfg.bk_norm,
                                   depthwise=cfg.bk_depthwise)
          # Stages are built in order P2 -> P5 so module construction (and thus
          # parameter creation order) matches a hand-unrolled definition.
          ## P2/4
          self.layer_2 = self._make_stage(cfg, self.feat_dims[0], self.feat_dims[1],
                                          num_blocks=round(3 * cfg.depth))
          ## P3/8
          self.layer_3 = self._make_stage(cfg, self.feat_dims[1], self.feat_dims[2],
                                          num_blocks=round(6 * cfg.depth))
          ## P4/16
          self.layer_4 = self._make_stage(cfg, self.feat_dims[2], self.feat_dims[3],
                                          num_blocks=round(6 * cfg.depth))
          ## P5/32
          self.layer_5 = self._make_stage(cfg, self.feat_dims[3], self.feat_dims[4],
                                          num_blocks=round(3 * cfg.depth))

          # Initialize all layers
          self.init_weights()

      @staticmethod
      def _make_stage(cfg, in_dim, out_dim, num_blocks):
          """Build one stage: a 3x3 stride-2 BasicConv followed by an ELANLayer.

          Args:
              cfg: config providing ``bk_act``, ``bk_norm`` and ``bk_depthwise``.
              in_dim: input channels of the downsampling conv.
              out_dim: output channels of the conv and of the ELANLayer.
              num_blocks: number of blocks passed to the ELANLayer.

          Returns:
              ``nn.Sequential`` whose index 0 is the conv and index 1 the ELANLayer
              (the same submodule layout/state_dict keys as an inline definition).
          """
          return nn.Sequential(
              BasicConv(in_dim, out_dim,
                        kernel_size=3, padding=1, stride=2,
                        act_type=cfg.bk_act, norm_type=cfg.bk_norm,
                        depthwise=cfg.bk_depthwise),
              ELANLayer(in_dim=out_dim,
                        out_dim=out_dim,
                        num_blocks=num_blocks,
                        expansion=0.5,
                        shortcut=True,
                        act_type=cfg.bk_act,
                        norm_type=cfg.bk_norm,
                        depthwise=cfg.bk_depthwise),
          )

      def init_weights(self):
          """Initialize the parameters."""
          for m in self.modules():
              if isinstance(m, torch.nn.Conv2d):
                  # In order to be consistent with the source code (presumably the
                  # reference implementation), reset the Conv2d initialization
                  # parameters.
                  m.reset_parameters()

      def forward(self, x):
          """Run the backbone.

          Args:
              x: input image tensor with 3 channels (layer_1 expects 3 input
                  channels); presumably (B, 3, H, W) — TODO confirm layout.

          Returns:
              list ``[c3, c4, c5]``: outputs of layer_3, layer_4 and layer_5
              (the P3/8, P4/16 and P5/32 levels).
          """
          c1 = self.layer_1(x)
          c2 = self.layer_2(c1)
          c3 = self.layer_3(c2)
          c4 = self.layer_4(c3)
          c5 = self.layer_5(c4)
          return [c3, c4, c5]
  96. # ---------------------------- Functions ----------------------------
  97. ## build Yolo's Backbone
  def build_backbone(cfg):
      """Create the YOLO backbone for the given config.

      Args:
          cfg: config object forwarded unchanged to ``Yolov8Backbone``.

      Returns:
          A freshly constructed ``Yolov8Backbone`` instance.
      """
      return Yolov8Backbone(cfg)
  if __name__ == '__main__':
      import time
      from thop import profile

      class BaseConfig(object):
          """Minimal demo config exposing the attributes the backbone reads."""

          def __init__(self) -> None:
              self.bk_act = 'silu'
              self.bk_norm = 'BN'
              self.bk_depthwise = False
              # NOTE(review): width/depth/ratio of 1.0 do not look like a typical
              # 'n'-scale setting; confirm which scale this demo is meant to be.
              self.width = 1.0
              self.depth = 1.0
              self.ratio = 1.0
              self.scale = "n"

      cfg = BaseConfig()
      model = build_backbone(cfg)

      # Time an inference-style forward pass: eval() puts norm layers into
      # inference mode (no running-stat updates) and no_grad() avoids timing
      # autograd graph construction. perf_counter() is a monotonic,
      # high-resolution clock intended for measuring intervals.
      model.eval()
      x = torch.randn(1, 3, 640, 640)
      with torch.no_grad():
          t0 = time.perf_counter()
          outputs = model(x)
          t1 = time.perf_counter()
      print('Time: ', t1 - t0)
      for out in outputs:
          print(out.shape)

      x = torch.randn(1, 3, 640, 640)
      print('==============================')
      flops, params = profile(model, inputs=(x, ), verbose=False)
      print('==============================')
      # The profiler's count is treated as multiply-accumulates here, hence x2
      # to report FLOPs.
      print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
      print('Params : {:.2f} M'.format(params / 1e6))