yolo11_backbone.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule, YoloStage, SPPF, C2PSA
  5. except:
  6. from modules import ConvModule, YoloStage, SPPF, C2PSA
  7. # ---------------------------- YOLO11 Backbone ----------------------------
  8. class Yolo11Backbone(nn.Module):
  9. def __init__(self, cfg):
  10. super(Yolo11Backbone, self).__init__()
  11. # ------------------ Basic setting ------------------
  12. self.model_scale = cfg.model_scale
  13. self.feat_dims = [int(512 * cfg.width), int(512 * cfg.width), int(512 * cfg.width * cfg.ratio)]
  14. # ------------------ Network setting ------------------
  15. ## P1/2
  16. self.layer_1 = ConvModule(3, int(64 * cfg.width), kernel_size=3, stride=2)
  17. # P2/4
  18. self.layer_2 = nn.Sequential(
  19. ConvModule(int(64 * cfg.width), int(128 * cfg.width), kernel_size=3, stride=2),
  20. YoloStage(in_dim = int(128 * cfg.width),
  21. out_dim = int(256 * cfg.width),
  22. num_blocks = round(2*cfg.depth),
  23. shortcut = True,
  24. expansion = 0.25,
  25. use_c3k = False if self.model_scale in "ns" else True,
  26. )
  27. )
  28. # P3/8
  29. self.layer_3 = nn.Sequential(
  30. ConvModule(int(256 * cfg.width), int(256 * cfg.width), kernel_size=3, stride=2),
  31. YoloStage(in_dim = int(256 * cfg.width),
  32. out_dim = int(512 * cfg.width),
  33. num_blocks = round(2*cfg.depth),
  34. shortcut = True,
  35. expansion = 0.25,
  36. use_c3k = False if self.model_scale in "ns" else True,
  37. )
  38. )
  39. # P4/16
  40. self.layer_4 = nn.Sequential(
  41. ConvModule(int(512 * cfg.width), int(512 * cfg.width), kernel_size=3, stride=2),
  42. YoloStage(in_dim = int(512 * cfg.width),
  43. out_dim = int(512 * cfg.width),
  44. num_blocks = round(2*cfg.depth),
  45. shortcut = True,
  46. expansion = 0.5,
  47. use_c3k = True,
  48. )
  49. )
  50. # P5/32
  51. self.layer_5 = nn.Sequential(
  52. ConvModule(int(512 * cfg.width), int(512 * cfg.width * cfg.ratio), kernel_size=3, stride=2),
  53. YoloStage(in_dim = int(512 * cfg.width * cfg.ratio),
  54. out_dim = int(512 * cfg.width * cfg.ratio),
  55. num_blocks = round(2*cfg.depth),
  56. shortcut = True,
  57. expansion = 0.5,
  58. use_c3k = True,
  59. )
  60. )
  61. # Extra module (no pretrained weight)
  62. self.layer_6 = SPPF(in_dim = int(512 * cfg.width * cfg.ratio),
  63. out_dim = int(512 * cfg.width * cfg.ratio),
  64. spp_pooling_size = 5,
  65. neck_expand_ratio = 0.5,
  66. )
  67. self.layer_7 = C2PSA(in_dim = int(512 * cfg.width * cfg.ratio),
  68. out_dim = int(512 * cfg.width * cfg.ratio),
  69. num_blocks = round(2*cfg.depth),
  70. expansion = 0.5,
  71. )
  72. # Initialize all layers
  73. self.init_weights()
  74. def init_weights(self):
  75. for m in self.modules():
  76. if isinstance(m, torch.nn.Conv2d):
  77. m.reset_parameters()
  78. def forward(self, x):
  79. c1 = self.layer_1(x)
  80. c2 = self.layer_2(c1)
  81. c3 = self.layer_3(c2)
  82. c4 = self.layer_4(c3)
  83. c5 = self.layer_5(c4)
  84. c5 = self.layer_6(c5)
  85. c5 = self.layer_7(c5)
  86. outputs = [c3, c4, c5]
  87. return outputs
  88. if __name__ == '__main__':
  89. import time
  90. from thop import profile
  91. class BaseConfig(object):
  92. def __init__(self) -> None:
  93. self.width = 0.25
  94. self.depth = 0.34
  95. self.ratio = 2.0
  96. self.model_scale = "n"
  97. cfg = BaseConfig()
  98. model = Yolo11Backbone(cfg)
  99. x = torch.randn(1, 3, 640, 640)
  100. t0 = time.time()
  101. outputs = model(x)
  102. t1 = time.time()
  103. print('Time: ', t1 - t0)
  104. for out in outputs:
  105. print(out.shape)
  106. x = torch.randn(1, 3, 640, 640)
  107. print('==============================')
  108. flops, params = profile(model, inputs=(x, ), verbose=False)
  109. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  110. print('Params : {:.2f} M'.format(params / 1e6))