# yolov7_backbone.py (5.0 KB)
  import torch
  import torch.nn as nn

  # Prefer the package-relative import; fall back to a plain import so the file
  # can also be executed directly as a script. Only ImportError is caught, so
  # genuine failures raised while importing `modules` are not masked.
  try:
      from .modules import ConvModule, ELANBlock, DownSample
  except ImportError:
      from modules import ConvModule, ELANBlock, DownSample

  # Download locations of the pretrained backbone checkpoints, keyed by name.
  in1k_pretrained_urls = {
      "elannet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_large.pth",
  }
  # --------------------- Yolov7 backbone -----------------------
  class Yolov7Backbone(nn.Module):
      """Five-stage convolutional backbone that returns its last three stage outputs.

      The network is a chain ``layer_1 -> layer_2 -> layer_3 -> layer_4 -> layer_5``.
      ``forward`` returns the outputs of ``layer_3``, ``layer_4`` and ``layer_5``.
      The ``P{n}/{stride}`` labels below are carried over from the original
      comments; the real strides depend on ``ConvModule`` / ``DownSample``,
      which are defined outside this file (TODO confirm).

      Args:
          use_pretrained: when True, the checkpoint referenced by
              ``in1k_pretrained_urls["elannet_large"]`` is downloaded and loaded
              right after the convolution weights have been re-initialized.
      """

      def __init__(self, use_pretrained: bool = False):
          super().__init__()
          # Channel widths consumed by the layers below (index 0 is the stem width).
          self.feat_dims = [32, 64, 128, 256, 512, 1024, 1024]
          self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
          self.branch_depths = [2, 2, 2, 2]            # Stage-1 -> Stage-4
          self.use_pretrained = use_pretrained

          # -------------------- Network parameters --------------------
          ## P1/2
          self.layer_1 = nn.Sequential(
              ConvModule(3, self.feat_dims[0], kernel_size=3),
              ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, stride=2),
              ConvModule(self.feat_dims[1], self.feat_dims[1], kernel_size=3),
          )
          ## P2/4: Stage-1
          self.layer_2 = nn.Sequential(
              ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, stride=2),
              ELANBlock(
                  in_dim=self.feat_dims[2],
                  out_dim=self.feat_dims[3],
                  expansion=self.squeeze_ratios[0],
                  branch_depth=self.branch_depths[0],
              ),
          )
          ## P3/8: Stage-2
          self.layer_3 = nn.Sequential(
              DownSample(self.feat_dims[3], self.feat_dims[3]),
              ELANBlock(
                  in_dim=self.feat_dims[3],
                  out_dim=self.feat_dims[4],
                  expansion=self.squeeze_ratios[1],
                  branch_depth=self.branch_depths[1],
              ),
          )
          ## P4/16: Stage-3
          self.layer_4 = nn.Sequential(
              DownSample(self.feat_dims[4], self.feat_dims[4]),
              ELANBlock(
                  in_dim=self.feat_dims[4],
                  out_dim=self.feat_dims[5],
                  expansion=self.squeeze_ratios[2],
                  branch_depth=self.branch_depths[2],
              ),
          )
          ## P5/32: Stage-4
          self.layer_5 = nn.Sequential(
              DownSample(self.feat_dims[5], self.feat_dims[5]),
              ELANBlock(
                  in_dim=self.feat_dims[5],
                  out_dim=self.feat_dims[6],
                  expansion=self.squeeze_ratios[3],
                  branch_depth=self.branch_depths[3],
              ),
          )

          # Initialize all layers
          self.init_weights()

      def init_weights(self):
          """Reset every ``Conv2d`` in the module tree, then optionally load the checkpoint."""
          for m in self.modules():
              if isinstance(m, torch.nn.Conv2d):
                  m.reset_parameters()

          # Load imagenet pretrained weight
          if self.use_pretrained:
              self.load_pretrained()

      def load_pretrained(self):
          """Download the pretrained checkpoint and copy in every compatible tensor.

          Checkpoint entries that have no counterpart in this module, or whose
          shape differs from the local parameter, are dropped and reported.
          The remaining entries are loaded non-strictly, so parameters without a
          usable checkpoint entry keep their freshly initialized values instead
          of aborting the load.
          """
          weight_name = "elannet_large"
          # .get() so that a missing entry reaches the message below rather than
          # raising KeyError.
          url = in1k_pretrained_urls.get(weight_name)
          if url is None:
              print('No pretrained weight for model scale: {}.'.format(weight_name))
              return

          print('Loading backbone pretrained weight from : {}'.format(url))
          # checkpoint state dict
          checkpoint = torch.hub.load_state_dict_from_url(
              url=url, map_location="cpu", check_hash=True)
          checkpoint_state_dict = checkpoint.pop("model")
          # model state dict
          model_state_dict = self.state_dict()

          # Iterate over a snapshot of the keys because entries are removed in the loop.
          for k in list(checkpoint_state_dict.keys()):
              if k not in model_state_dict:
                  checkpoint_state_dict.pop(k)
                  print('Unused key: ', k)
                  continue
              shape_model = tuple(model_state_dict[k].shape)
              shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
              if shape_model != shape_checkpoint:
                  checkpoint_state_dict.pop(k)
                  print('Shape mismatch, skipped key: ', k)

          # The filtering above may leave gaps, which a strict load would reject.
          incompatible = self.load_state_dict(checkpoint_state_dict, strict=False)
          if incompatible.missing_keys:
              print('Missing keys (left at their initialized values): ',
                    incompatible.missing_keys)

      def forward(self, x):
          """Run all five layers in order and return ``[c3, c4, c5]``.

          ``c3``, ``c4`` and ``c5`` are the outputs of ``layer_3``, ``layer_4``
          and ``layer_5`` respectively.
          """
          c1 = self.layer_1(x)
          c2 = self.layer_2(c1)
          c3 = self.layer_3(c2)
          c4 = self.layer_4(c3)
          c5 = self.layer_5(c4)
          outputs = [c3, c4, c5]
          return outputs
  if __name__ == '__main__':
      # Manual smoke test: build the backbone, print its output shapes, then
      # report the operation and parameter counts measured by thop.
      from thop import profile

      # Backbone with pretrained weights requested
      backbone = Yolov7Backbone(use_pretrained=True)

      # Random batch of two 640x640 RGB-sized inputs
      batch = torch.randn(2, 3, 640, 640)

      # Forward pass and shape report
      feature_maps = backbone(batch)
      print(' - the shape of input : ', batch.shape)
      for feature_map in feature_maps:
          print(' - the shape of output : ', feature_map.shape)

      # Profiling uses a separate single-sample input.
      single = torch.randn(1, 3, 640, 640)
      flops, params = profile(backbone, inputs=(single, ), verbose=False)
      # The count is doubled before printing; presumably thop reports
      # multiply-accumulates rather than FLOPs (TODO confirm).
      giga_flops = flops / 1e9 * 2
      mega_params = params / 1e6
      print('============== FLOPs & Params ================')
      print(' - FLOPs : {:.2f} G'.format(giga_flops))
      print(' - Params : {:.2f} M'.format(mega_params))