# yolov7_af_backbone.py (5.0 KB)
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov7_af_basic import BasicConv, MDown, ELANLayer
  5. except:
  6. from yolov7_af_basic import BasicConv, MDown, ELANLayer
  # ELANNet: the YOLOv7 backbone built from ELAN stages.
class Yolov7Backbone(nn.Module):
    """YOLOv7 ELAN backbone returning three feature maps ``[c3, c4, c5]``.

    The network is a stem (``layer_1``) followed by four stages
    (``layer_2`` .. ``layer_5``). Each stage first downsamples and then
    applies an ``ELANLayer``.

    Args:
        cfg: configuration object. The attributes read here are:
            ``scale`` ("t", "l" or "x"); ``width`` (multiplier applied to the
            base channel widths 64/128/256/512/1024); and ``bk_act``,
            ``bk_norm`` and ``bk_depthwise``. The last three are forwarded
            unchanged to every ``BasicConv`` / ``MDown`` / ``ELANLayer``.

    Raises:
        NotImplementedError: if ``cfg.scale`` is not a supported scale.
    """

    def __init__(self, cfg):
        super(Yolov7Backbone, self).__init__()
        # ---------------- Basic parameters ----------------
        self.model_scale = cfg.scale
        self.bk_act = cfg.bk_act
        self.bk_norm = cfg.bk_norm
        self.bk_depthwise = cfg.bk_depthwise
        if self.model_scale in ["l", "x"]:
            self.elan_depth = 2
            # Only feat_dims[0..4] are consumed by the layers below.
            # The sixth entry is kept so the attribute stays unchanged for any
            # external reader. NOTE(review): confirm whether it is still needed.
            self.feat_dims = [round(64 * cfg.width), round(128 * cfg.width), round(256 * cfg.width),
                              round(512 * cfg.width), round(1024 * cfg.width), round(1024 * cfg.width)]
            self.last_stage_eratio = 0.25
        elif self.model_scale in ["t"]:
            self.elan_depth = 1
            self.feat_dims = [round(64 * cfg.width), round(128 * cfg.width),
                              round(256 * cfg.width), round(512 * cfg.width), round(1024 * cfg.width)]
            self.last_stage_eratio = 0.5
        else:
            # Without this branch an unknown scale leaves feat_dims unset.
            # Construction would then die with an AttributeError below.
            # Fail explicitly, with the same error make_stem/make_block use.
            raise NotImplementedError("Unknown model scale: {}".format(self.model_scale))
        # ---------------- Model parameters ----------------
        self.layer_1 = self.make_stem(3, self.feat_dims[0])
        self.layer_2 = self.make_block(self.feat_dims[0], self.feat_dims[1], expansion=0.5)
        self.layer_3 = self.make_block(self.feat_dims[1], self.feat_dims[2], expansion=0.5)
        self.layer_4 = self.make_block(self.feat_dims[2], self.feat_dims[3], expansion=0.5)
        # The last stage uses a scale-dependent expansion ratio (0.25 for l/x, 0.5 for t).
        self.layer_5 = self.make_block(self.feat_dims[3], self.feat_dims[4], expansion=self.last_stage_eratio)
        # Initialize all layers
        self.init_weights()

    def init_weights(self):
        """Re-run PyTorch's default initialization on every ``nn.Conv2d``.

        This walks all submodules, including convolutions nested inside
        ``BasicConv``/``MDown``/``ELANLayer``.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()

    def make_stem(self, in_dim, out_dim):
        """Build the stem that maps ``in_dim`` input channels to ``out_dim``.

        - "l"/"x": three 3x3 convs (in -> out/2 -> out -> out). Only the
          second conv has stride 2.
        - "t": a single 6x6 conv with stride 2 and padding 2.

        Raises:
            NotImplementedError: for an unsupported ``self.model_scale``.
        """
        if self.model_scale in ["l", "x"]:
            stem = nn.Sequential(
                BasicConv(in_dim, out_dim//2, kernel_size=3, padding=1, stride=1,
                          act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise),
                BasicConv(out_dim//2, out_dim, kernel_size=3, padding=1, stride=2,
                          act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise),
                BasicConv(out_dim, out_dim, kernel_size=3, padding=1, stride=1,
                          act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise)
            )
        elif self.model_scale in ["t"]:
            stem = BasicConv(in_dim, out_dim, kernel_size=6, padding=2, stride=2,
                             act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise)
        else:
            raise NotImplementedError("Unknown model scale: {}".format(self.model_scale))
        return stem

    def make_block(self, in_dim, out_dim, expansion=0.5):
        """Build one backbone stage: a downsampling step followed by an ``ELANLayer``.

        - "l"/"x": ``MDown(in_dim, out_dim)`` followed by
          ``ELANLayer(out_dim, out_dim)``. NOTE(review): MDown is defined in
          yolov7_af_basic; it presumably halves the spatial size.
        - "t": 2x2 max-pool with stride 2, then ``ELANLayer(in_dim, out_dim)``.

        ``expansion`` and ``self.elan_depth`` (as ``num_blocks``) are passed
        to the ``ELANLayer``.

        Raises:
            NotImplementedError: for an unsupported ``self.model_scale``.
        """
        if self.model_scale in ["l", "x"]:
            block = nn.Sequential(
                MDown(in_dim, out_dim,
                      act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise),
                ELANLayer(out_dim, out_dim,
                          expansion=expansion, num_blocks=self.elan_depth,
                          act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise),
            )
        elif self.model_scale in ["t"]:
            block = nn.Sequential(
                nn.MaxPool2d((2, 2), stride=2),
                ELANLayer(in_dim, out_dim,
                          expansion=expansion, num_blocks=self.elan_depth,
                          act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise),
            )
        else:
            raise NotImplementedError("Unknown model scale: {}".format(self.model_scale))
        return block

    def forward(self, x):
        """Run the backbone and return ``[c3, c4, c5]``.

        These are the outputs of ``layer_3``, ``layer_4`` and ``layer_5``.
        For scale "t" (stride-2 stem plus a stride-2 max-pool per stage) they
        are at strides 8, 16 and 32 relative to ``x``. For "l"/"x" the same
        holds only if ``MDown`` halves the resolution (TODO confirm). ``x`` is
        presumably an image batch of shape (B, 3, H, W), since the stem takes
        3 input channels.
        """
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)
        outputs = [c3, c4, c5]
        return outputs
  # Script entry point: smoke-test the tiny backbone and report its cost.
if __name__ == '__main__':
    import time
    from thop import profile

    class BaseConfig:
        """Minimal config for a quick run of the tiny ("t") backbone."""

        def __init__(self) -> None:
            self.scale = "t"
            self.width = 0.5
            self.depth = 0.34  # not read by Yolov7Backbone
            self.bk_act = 'silu'
            self.bk_norm = 'BN'
            self.bk_depthwise = False

    backbone = Yolov7Backbone(BaseConfig())

    # One wall-clock-timed forward pass on a random 640x640 image.
    dummy = torch.randn(1, 3, 640, 640)
    tic = time.time()
    feats = backbone(dummy)
    toc = time.time()
    print('Time: ', toc - tic)
    for feat in feats:
        print(feat.shape)

    # Complexity report on a fresh random input. The x2 presumably converts
    # thop's multiply-accumulate count into FLOPs.
    dummy = torch.randn(1, 3, 640, 640)
    print('==============================')
    flops, params = profile(backbone, inputs=(dummy, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))