# yolov6_backbone.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov6_basic import RepBlock, RepVGGBlock, RepCSPBlock
  5. except:
  6. from yolov6_basic import RepBlock, RepVGGBlock, RepCSPBlock
  7. # IN1K pretrained weight
  8. pretrained_urls = {
  9. 'n': None,
  10. 's': None,
  11. 'm': None,
  12. 'l': None,
  13. }
  14. # --------------------- Yolov3's Backbone -----------------------
  15. ## Modified DarkNet
  16. class Yolov6Backbone(nn.Module):
  17. def __init__(self, cfg):
  18. super(Yolov6Backbone, self).__init__()
  19. # ------------------ Basic setting ------------------
  20. self.cfg = cfg
  21. self.model_scale = cfg.scale
  22. self.feat_dims = [round(64 * cfg.width),
  23. round(128 * cfg.width),
  24. round(256 * cfg.width),
  25. round(512 * cfg.width),
  26. round(1024 * cfg.width)]
  27. # ------------------ Network setting ------------------
  28. ## P1/2
  29. self.layer_1 = RepVGGBlock(3, self.feat_dims[0],
  30. kernel_size=3, padding=1, stride=2)
  31. # P2/4
  32. self.layer_2 = self.make_block(self.feat_dims[0], self.feat_dims[1], round(6*cfg.depth))
  33. # P3/8
  34. self.layer_3 = self.make_block(self.feat_dims[1], self.feat_dims[2], round(12*cfg.depth))
  35. # P4/16
  36. self.layer_4 = self.make_block(self.feat_dims[2], self.feat_dims[3], round(18*cfg.depth))
  37. # P5/32
  38. self.layer_5 = self.make_block(self.feat_dims[3], self.feat_dims[4], round(6*cfg.depth))
  39. # Initialize all layers
  40. self.init_weights()
  41. # Load imagenet pretrained weight
  42. if cfg.use_pretrained:
  43. self.load_pretrained()
  44. def init_weights(self):
  45. """Initialize the parameters."""
  46. for m in self.modules():
  47. if isinstance(m, torch.nn.Conv2d):
  48. # In order to be consistent with the source code,
  49. # reset the Conv2d initialization parameters
  50. m.reset_parameters()
  51. def load_pretrained(self):
  52. url = pretrained_urls[self.model_scale]
  53. if url is not None:
  54. print('Loading backbone pretrained weight from : {}'.format(url))
  55. # checkpoint state dict
  56. checkpoint = torch.hub.load_state_dict_from_url(
  57. url=url, map_location="cpu", check_hash=True)
  58. checkpoint_state_dict = checkpoint.pop("model")
  59. # model state dict
  60. model_state_dict = self.state_dict()
  61. # check
  62. for k in list(checkpoint_state_dict.keys()):
  63. if k in model_state_dict:
  64. shape_model = tuple(model_state_dict[k].shape)
  65. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  66. if shape_model != shape_checkpoint:
  67. checkpoint_state_dict.pop(k)
  68. else:
  69. checkpoint_state_dict.pop(k)
  70. print('Unused key: ', k)
  71. # load the weight
  72. self.load_state_dict(checkpoint_state_dict)
  73. else:
  74. print('No pretrained weight for model scale: {}.'.format(self.model_scale))
  75. def make_block(self, in_dim, out_dim, num_blocks=1):
  76. if self.model_scale in ["n", "s"]:
  77. block = nn.Sequential(
  78. RepVGGBlock(in_dim, out_dim,
  79. kernel_size=3, padding=1, stride=2),
  80. RepBlock(in_channels = out_dim,
  81. out_channels = out_dim,
  82. num_blocks = num_blocks,
  83. block = RepVGGBlock)
  84. )
  85. elif self.model_scale in ["m", "l"]:
  86. block = nn.Sequential(
  87. RepVGGBlock(in_dim, out_dim,
  88. kernel_size=3, padding=1, stride=2),
  89. RepCSPBlock(in_channels = out_dim,
  90. out_channels = out_dim,
  91. num_blocks = num_blocks,
  92. expansion = self.cfg.bk_csp_expansion)
  93. )
  94. else:
  95. raise NotImplementedError("Unknown model scale: {}".format(self.model_scale))
  96. return block
  97. def forward(self, x):
  98. c1 = self.layer_1(x)
  99. c2 = self.layer_2(c1)
  100. c3 = self.layer_3(c2)
  101. c4 = self.layer_4(c3)
  102. c5 = self.layer_5(c4)
  103. outputs = [c3, c4, c5]
  104. return outputs
  105. if __name__ == '__main__':
  106. import time
  107. from thop import profile
  108. class BaseConfig(object):
  109. def __init__(self) -> None:
  110. self.bk_depthwise = False
  111. self.width = 0.50
  112. self.depth = 0.34
  113. self.scale = "s"
  114. self.use_pretrained = True
  115. cfg = BaseConfig()
  116. model = Yolov6Backbone(cfg)
  117. x = torch.randn(1, 3, 640, 640)
  118. t0 = time.time()
  119. outputs = model(x)
  120. t1 = time.time()
  121. print('Time: ', t1 - t0)
  122. for out in outputs:
  123. print(out.shape)
  124. for m in model.modules():
  125. if hasattr(m, "switch_to_deploy"):
  126. m.switch_to_deploy()
  127. x = torch.randn(1, 3, 640, 640)
  128. print('==============================')
  129. flops, params = profile(model, inputs=(x, ), verbose=False)
  130. print('==============================')
  131. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  132. print('Params : {:.2f} M'.format(params / 1e6))