# yolov7_backbone.py

import torch
import torch.nn as nn

try:
    from .modules import ConvModule, ELANBlock, DownSample
except ImportError:
    from modules import ConvModule, ELANBlock, DownSample


in1k_pretrained_urls = {
    "elannet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_large.pth",
}
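
# ----------------------------------------------------------------------------
# Reference sketches of the three imported blocks, for readers without
# modules.py at hand. These are assumptions about the typical YOLOv7-style
# layout (Conv-BN-SiLU units, a four-path ELAN, pool+conv downsampling), not
# the authoritative definitions; the names carry a _Ref prefix so they never
# shadow the real imports above.
# ----------------------------------------------------------------------------
class _RefConvModule(nn.Module):
    """Assumed ConvModule layout: Conv2d -> BatchNorm2d -> SiLU."""
    def __init__(self, in_dim, out_dim, kernel_size=1, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride,
                              padding=kernel_size // 2, bias=False)
        self.norm = nn.BatchNorm2d(out_dim)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class _RefELANBlock(nn.Module):
    """Assumed ELAN layout: two 1x1 stems plus two stacked 3x3 branches, all
    four concatenated (4 * inter_dim channels) and fused by a 1x1 conv. With
    the squeeze ratios used below, 4 * inter_dim always equals out_dim."""
    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_depth=2):
        super().__init__()
        inter_dim = int(in_dim * squeeze_ratio)
        self.cv1 = _RefConvModule(in_dim, inter_dim)
        self.cv2 = _RefConvModule(in_dim, inter_dim)
        self.cv3 = nn.Sequential(*[_RefConvModule(inter_dim, inter_dim, kernel_size=3)
                                   for _ in range(branch_depth)])
        self.cv4 = nn.Sequential(*[_RefConvModule(inter_dim, inter_dim, kernel_size=3)
                                   for _ in range(branch_depth)])
        self.out = _RefConvModule(inter_dim * 4, out_dim)

    def forward(self, x):
        x1 = self.cv1(x)
        x2 = self.cv2(x)
        x3 = self.cv3(x2)
        x4 = self.cv4(x3)
        return self.out(torch.cat([x1, x2, x3, x4], dim=1))


class _RefDownSample(nn.Module):
    """Assumed 2x downsampling: a maxpool branch and a strided-conv branch,
    each producing out_dim // 2 channels, concatenated."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        inter_dim = out_dim // 2
        self.branch_1 = nn.Sequential(nn.MaxPool2d(2, 2),
                                      _RefConvModule(in_dim, inter_dim))
        self.branch_2 = nn.Sequential(_RefConvModule(in_dim, inter_dim),
                                      _RefConvModule(inter_dim, inter_dim,
                                                     kernel_size=3, stride=2))

    def forward(self, x):
        return torch.cat([self.branch_1(x), self.branch_2(x)], dim=1)
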
# --------------------- YOLOv7 backbone (ELANNet-Large) -----------------------
class Yolov7Backbone(nn.Module):
    def __init__(self, use_pretrained: bool = False):
        super(Yolov7Backbone, self).__init__()
        self.feat_dims = [32, 64, 128, 256, 512, 1024, 1024]
        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
        self.branch_depths = [2, 2, 2, 2]            # Stage-1 -> Stage-4
        self.use_pretrained = use_pretrained

        # -------------------- Network parameters --------------------
        ## P1/2
        self.layer_1 = nn.Sequential(
            ConvModule(3, self.feat_dims[0], kernel_size=3),
            ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, stride=2),
            ConvModule(self.feat_dims[1], self.feat_dims[1], kernel_size=3)
        )
        ## P2/4: Stage-1
        self.layer_2 = nn.Sequential(
            ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, stride=2),
            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0])
        )
        ## P3/8: Stage-2
        self.layer_3 = nn.Sequential(
            DownSample(self.feat_dims[3], self.feat_dims[3]),
            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1])
        )
        ## P4/16: Stage-3
        self.layer_4 = nn.Sequential(
            DownSample(self.feat_dims[4], self.feat_dims[4]),
            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2])
        )
        ## P5/32: Stage-4
        self.layer_5 = nn.Sequential(
            DownSample(self.feat_dims[5], self.feat_dims[5]),
            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3])
        )

        # Initialize all layers
        self.init_weights()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                m.reset_parameters()
        # Load the ImageNet pretrained weight
        if self.use_pretrained:
            self.load_pretrained()

    def load_pretrained(self):
        url = in1k_pretrained_urls["elannet_large"]
        if url is not None:
            print('Loading backbone pretrained weight from : {}'.format(url))
            # Checkpoint state dict (the released checkpoint stores the
            # weights under the "model" key)
            checkpoint = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", check_hash=True)
            checkpoint_state_dict = checkpoint.pop("model")
            # Model state dict
            model_state_dict = self.state_dict()
            # Drop checkpoint keys that are absent from the model or whose
            # shapes do not match
            for k in list(checkpoint_state_dict.keys()):
                if k in model_state_dict:
                    shape_model = tuple(model_state_dict[k].shape)
                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                    if shape_model != shape_checkpoint:
                        checkpoint_state_dict.pop(k)
                        print('Shape mismatch, skipping key: ', k)
                else:
                    checkpoint_state_dict.pop(k)
                    print('Unused key: ', k)
            # Load the weight; strict=False because mismatched keys may have
            # been dropped above
            self.load_state_dict(checkpoint_state_dict, strict=False)
        else:
            print('No pretrained weight for backbone: elannet_large.')

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)
        outputs = [c3, c4, c5]

        return outputs
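
# Expected pyramid for a 640x640 input with the default feat_dims (assuming
# each ELANBlock emits out_dim channels, as its signature suggests):
#   c3: [B, 512,  80, 80]  (P3/8)
#   c4: [B, 1024, 40, 40]  (P4/16)
#   c5: [B, 1024, 20, 20]  (P5/32)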

if __name__ == '__main__':
    from thop import profile

    # Build the backbone
    model = Yolov7Backbone(use_pretrained=True)

    # Randomly generate input data
    x = torch.randn(2, 3, 640, 640)

    # Inference
    outputs = model(x)
    print(' - the shape of input : ', x.shape)
    for out in outputs:
        print(' - the shape of output : ', out.shape)

    # thop reports MACs, so multiply by 2 to get FLOPs
    x = torch.randn(1, 3, 640, 640)
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('============== FLOPs & Params ================')
    print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
    print(' - Params : {:.2f} M'.format(params / 1e6))
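
    # Usage sketch (illustrative addition, not part of the original script):
    # frozen feature extraction, as the backbone would typically be driven
    # inside a detector.
    model.eval()
    with torch.no_grad():
        c3, c4, c5 = model(torch.randn(1, 3, 640, 640))
    print(' - P3/P4/P5 channels : ', c3.shape[1], c4.shape[1], c5.shape[1])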