yolov8_backbone.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule, C2fBlock
  5. except:
  6. from modules import ConvModule, C2fBlock
  7. # IN1K pretrained weight
  8. pretrained_urls = {
  9. 'n': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_n_in1k_62.1.pth",
  10. 's': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_s_in1k_71.3.pth",
  11. 'm': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_m_in1k_75.7.pth",
  12. 'l': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_l_in1k_77.3.pth",
  13. 'x': None,
  14. }
  15. # ---------------------------- Basic functions ----------------------------
  16. class Yolov8Backbone(nn.Module):
  17. def __init__(self, cfg):
  18. super(Yolov8Backbone, self).__init__()
  19. # ------------------ Basic setting ------------------
  20. self.model_scale = cfg.model_scale
  21. self.feat_dims = [round(64 * cfg.width),
  22. round(128 * cfg.width),
  23. round(256 * cfg.width),
  24. round(512 * cfg.width),
  25. round(512 * cfg.width * cfg.ratio)]
  26. # ------------------ Network setting ------------------
  27. ## P1/2
  28. self.layer_1 = ConvModule(3, self.feat_dims[0], kernel_size=3, padding=1, stride=2)
  29. # P2/4
  30. self.layer_2 = nn.Sequential(
  31. ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, padding=1, stride=2),
  32. C2fBlock(in_dim = self.feat_dims[1],
  33. out_dim = self.feat_dims[1],
  34. num_blocks = round(3*cfg.depth),
  35. expansion = 0.5,
  36. shortcut = True,
  37. )
  38. )
  39. # P3/8
  40. self.layer_3 = nn.Sequential(
  41. ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, padding=1, stride=2),
  42. C2fBlock(in_dim = self.feat_dims[2],
  43. out_dim = self.feat_dims[2],
  44. num_blocks = round(6*cfg.depth),
  45. expansion = 0.5,
  46. shortcut = True,
  47. )
  48. )
  49. # P4/16
  50. self.layer_4 = nn.Sequential(
  51. ConvModule(self.feat_dims[2], self.feat_dims[3], kernel_size=3, padding=1, stride=2),
  52. C2fBlock(in_dim = self.feat_dims[3],
  53. out_dim = self.feat_dims[3],
  54. num_blocks = round(6*cfg.depth),
  55. expansion = 0.5,
  56. shortcut = True,
  57. )
  58. )
  59. # P5/32
  60. self.layer_5 = nn.Sequential(
  61. ConvModule(self.feat_dims[3], self.feat_dims[4], kernel_size=3, padding=1, stride=2),
  62. C2fBlock(in_dim = self.feat_dims[4],
  63. out_dim = self.feat_dims[4],
  64. num_blocks = round(3*cfg.depth),
  65. expansion = 0.5,
  66. shortcut = True,
  67. )
  68. )
  69. # Initialize all layers
  70. self.init_weights()
  71. # Load imagenet pretrained weight
  72. if cfg.use_pretrained:
  73. self.load_pretrained()
  74. def init_weights(self):
  75. """Initialize the parameters."""
  76. for m in self.modules():
  77. if isinstance(m, torch.nn.Conv2d):
  78. m.reset_parameters()
  79. def load_pretrained(self):
  80. url = pretrained_urls[self.model_scale]
  81. if url is not None:
  82. print('Loading backbone pretrained weight from : {}'.format(url))
  83. # checkpoint state dict
  84. checkpoint = torch.hub.load_state_dict_from_url(
  85. url=url, map_location="cpu", check_hash=True)
  86. checkpoint_state_dict = checkpoint.pop("model")
  87. # model state dict
  88. model_state_dict = self.state_dict()
  89. # check
  90. for k in list(checkpoint_state_dict.keys()):
  91. if k in model_state_dict:
  92. shape_model = tuple(model_state_dict[k].shape)
  93. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  94. if shape_model != shape_checkpoint:
  95. checkpoint_state_dict.pop(k)
  96. else:
  97. checkpoint_state_dict.pop(k)
  98. print('Unused key: ', k)
  99. # load the weight
  100. self.load_state_dict(checkpoint_state_dict)
  101. else:
  102. print('No pretrained weight for model scale: {}.'.format(self.model_scale))
  103. def forward(self, x):
  104. c1 = self.layer_1(x)
  105. c2 = self.layer_2(c1)
  106. c3 = self.layer_3(c2)
  107. c4 = self.layer_4(c3)
  108. c5 = self.layer_5(c4)
  109. outputs = [c3, c4, c5]
  110. return outputs
  111. if __name__ == '__main__':
  112. import time
  113. from thop import profile
  114. # YOLOv8 config
  115. class BaseConfig(object):
  116. def __init__(self) -> None:
  117. self.use_pretrained = False
  118. self.width = 0.50
  119. self.depth = 0.34
  120. self.ratio = 2.00
  121. self.model_scale = "s"
  122. cfg = BaseConfig()
  123. # Build backbone
  124. model = Yolov8Backbone(cfg)
  125. # Randomly generate a input data
  126. x = torch.randn(2, 3, 640, 640)
  127. # Inference
  128. outputs = model(x)
  129. print(' - the shape of input : ', x.shape)
  130. for out in outputs:
  131. print(' - the shape of output : ', out.shape)
  132. x = torch.randn(1, 3, 640, 640)
  133. flops, params = profile(model, inputs=(x, ), verbose=False)
  134. print('============== FLOPs & Params ================')
  135. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  136. print(' - Params : {:.2f} M'.format(params / 1e6))