# yolov8_backbone.py

import torch
import torch.nn as nn

try:
    from .yolov8_basic import BasicConv, ELANLayer
except ImportError:
    from yolov8_basic import BasicConv, ELANLayer
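
# NOTE: BasicConv and ELANLayer are defined in yolov8_basic and are not shown in this file.
# A rough sketch of the interface assumed by the call sites below (an inference, not the
# actual implementation):
#   BasicConv(in_dim, out_dim, kernel_size, padding, stride, act_type, norm_type, depthwise)
#       -> Conv2d + normalization + activation (depthwise-separable when depthwise=True)
#   ELANLayer(in_dim, out_dim, num_blocks, expansion, shortcut, act_type, norm_type, depthwise)
#       -> a YOLOv8 C2f-style ELAN block that splits channels, stacks num_blocks bottlenecks,
#          and concatenates/fuses the intermediate outputs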

# ImageNet-1K (IN1K) pretrained weights; entries set to None have no released checkpoint.
pretrained_urls = {
    'n': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_n_in1k_62.1.pth",
    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_s_in1k_71.3.pth",
    'm': None,
    'l': None,
    'x': None,
}
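
# torch.hub.load_state_dict_from_url() (used by load_pretrained below) downloads these files
# once and caches them, by default under $TORCH_HOME/hub/checkpoints
# (~/.cache/torch/hub/checkpoints on Linux).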


# ---------------------------- Yolov8 Backbone ----------------------------
class Yolov8Backbone(nn.Module):
    """ELAN-DarkNet backbone for the YOLOv8-style detector.

    The input is downsampled by 2x per stage (P1/2 ... P5/32) and the C3, C4, and C5
    feature maps are returned for the neck/FPN.
    """
    def __init__(self, cfg):
        super(Yolov8Backbone, self).__init__()
        # ------------------ Basic setting ------------------
        self.model_scale = cfg.scale
        self.feat_dims = [round(64  * cfg.width),
                          round(128 * cfg.width),
                          round(256 * cfg.width),
                          round(512 * cfg.width),
                          round(512 * cfg.width * cfg.ratio)]

        # ------------------ Network setting ------------------
        ## P1/2
        self.layer_1 = BasicConv(3, self.feat_dims[0],
                                 kernel_size=3, padding=1, stride=2,
                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
        ## P2/4
        self.layer_2 = nn.Sequential(
            BasicConv(self.feat_dims[0], self.feat_dims[1],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ELANLayer(in_dim     = self.feat_dims[1],
                      out_dim    = self.feat_dims[1],
                      num_blocks = round(3 * cfg.depth),
                      expansion  = 0.5,
                      shortcut   = True,
                      act_type   = cfg.bk_act,
                      norm_type  = cfg.bk_norm,
                      depthwise  = cfg.bk_depthwise),
        )
        ## P3/8
        self.layer_3 = nn.Sequential(
            BasicConv(self.feat_dims[1], self.feat_dims[2],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ELANLayer(in_dim     = self.feat_dims[2],
                      out_dim    = self.feat_dims[2],
                      num_blocks = round(6 * cfg.depth),
                      expansion  = 0.5,
                      shortcut   = True,
                      act_type   = cfg.bk_act,
                      norm_type  = cfg.bk_norm,
                      depthwise  = cfg.bk_depthwise),
        )
        ## P4/16
        self.layer_4 = nn.Sequential(
            BasicConv(self.feat_dims[2], self.feat_dims[3],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ELANLayer(in_dim     = self.feat_dims[3],
                      out_dim    = self.feat_dims[3],
                      num_blocks = round(6 * cfg.depth),
                      expansion  = 0.5,
                      shortcut   = True,
                      act_type   = cfg.bk_act,
                      norm_type  = cfg.bk_norm,
                      depthwise  = cfg.bk_depthwise),
        )
        ## P5/32
        self.layer_5 = nn.Sequential(
            BasicConv(self.feat_dims[3], self.feat_dims[4],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ELANLayer(in_dim     = self.feat_dims[4],
                      out_dim    = self.feat_dims[4],
                      num_blocks = round(3 * cfg.depth),
                      expansion  = 0.5,
                      shortcut   = True,
                      act_type   = cfg.bk_act,
                      norm_type  = cfg.bk_norm,
                      depthwise  = cfg.bk_depthwise),
        )

        # Initialize all layers
        self.init_weights()

        # Load the ImageNet pretrained weights, if requested
        if cfg.use_pretrained:
            self.load_pretrained()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()

    def load_pretrained(self):
        url = pretrained_urls[self.model_scale]
        if url is not None:
            print('Loading backbone pretrained weight from : {}'.format(url))
            # checkpoint state dict
            checkpoint = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", check_hash=True)
            checkpoint_state_dict = checkpoint.pop("model")
            # model state dict
            model_state_dict = self.state_dict()
            # drop checkpoint keys that are missing from the model or have a mismatched shape
            for k in list(checkpoint_state_dict.keys()):
                if k in model_state_dict:
                    shape_model = tuple(model_state_dict[k].shape)
                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                    if shape_model != shape_checkpoint:
                        checkpoint_state_dict.pop(k)
                else:
                    checkpoint_state_dict.pop(k)
                    print('Unused key: ', k)
            # load the (possibly partial) weights; strict=False because some keys
            # may have been dropped above
            self.load_state_dict(checkpoint_state_dict, strict=False)
        else:
            print('No pretrained weight for model scale: {}.'.format(self.model_scale))

    def forward(self, x):
        c1 = self.layer_1(x)   # P1/2
        c2 = self.layer_2(c1)  # P2/4
        c3 = self.layer_3(c2)  # P3/8
        c4 = self.layer_4(c3)  # P4/16
        c5 = self.layer_5(c4)  # P5/32
        outputs = [c3, c4, c5]

        return outputs
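
# For reference: with the 's' config used in the __main__ block below (width=0.50, ratio=2.0)
# and a 640x640 input, the forward pass returns three feature maps:
#   c3: [1, 128, 80, 80]   (stride 8)
#   c4: [1, 256, 40, 40]   (stride 16)
#   c5: [1, 512, 20, 20]   (stride 32)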


# ---------------------------- Functions ----------------------------
## Build the YOLOv8 backbone
def build_backbone(cfg):
    # model
    backbone = Yolov8Backbone(cfg)

    return backbone
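
# The cfg object passed to build_backbone() only needs the attributes read above:
#   scale, width, depth, ratio, bk_act, bk_norm, bk_depthwise, use_pretrained
# (the BaseConfig class in the __main__ block below shows one concrete set of values).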


if __name__ == '__main__':
    import time
    from thop import profile

    class BaseConfig(object):
        def __init__(self) -> None:
            self.bk_act = 'silu'
            self.bk_norm = 'BN'
            self.bk_depthwise = False
            self.use_pretrained = True
            self.width = 0.50
            self.depth = 0.34
            self.ratio = 2.0
            self.scale = "s"

    cfg = BaseConfig()
    model = build_backbone(cfg)

    x = torch.randn(1, 3, 640, 640)
    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    x = torch.randn(1, 3, 640, 640)
    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    # thop reports MACs; multiply by 2 to convert to FLOPs
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))