yolov5_af_backbone.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov5_af_basic import BasicConv, CSPBlock
  5. except:
  6. from yolov5_af_basic import BasicConv, CSPBlock
  7. # --------------------- Yolov5's Backbone -----------------------
  8. ## Modified CSP-DarkNet
  9. class Yolov5Backbone(nn.Module):
  10. def __init__(self, cfg):
  11. super(Yolov5Backbone, self).__init__()
  12. # ------------------ Basic setting ------------------
  13. self.model_scale = cfg.scale
  14. self.feat_dims = [round(64 * cfg.width),
  15. round(128 * cfg.width),
  16. round(256 * cfg.width),
  17. round(512 * cfg.width),
  18. round(1024 * cfg.width)]
  19. # ------------------ Network setting ------------------
  20. ## P1/2
  21. self.layer_1 = BasicConv(3, self.feat_dims[0],
  22. kernel_size=6, padding=2, stride=2,
  23. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
  24. # P2/4
  25. self.layer_2 = nn.Sequential(
  26. BasicConv(self.feat_dims[0], self.feat_dims[1],
  27. kernel_size=3, padding=1, stride=2,
  28. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  29. CSPBlock(in_dim = self.feat_dims[1],
  30. out_dim = self.feat_dims[1],
  31. num_blocks = round(3*cfg.depth),
  32. expansion = 0.5,
  33. shortcut = True,
  34. act_type = cfg.bk_act,
  35. norm_type = cfg.bk_norm,
  36. depthwise = cfg.bk_depthwise)
  37. )
  38. # P3/8
  39. self.layer_3 = nn.Sequential(
  40. BasicConv(self.feat_dims[1], self.feat_dims[2],
  41. kernel_size=3, padding=1, stride=2,
  42. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  43. CSPBlock(in_dim = self.feat_dims[2],
  44. out_dim = self.feat_dims[2],
  45. num_blocks = round(9*cfg.depth),
  46. expansion = 0.5,
  47. shortcut = True,
  48. act_type = cfg.bk_act,
  49. norm_type = cfg.bk_norm,
  50. depthwise = cfg.bk_depthwise)
  51. )
  52. # P4/16
  53. self.layer_4 = nn.Sequential(
  54. BasicConv(self.feat_dims[2], self.feat_dims[3],
  55. kernel_size=3, padding=1, stride=2,
  56. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  57. CSPBlock(in_dim = self.feat_dims[3],
  58. out_dim = self.feat_dims[3],
  59. num_blocks = round(9*cfg.depth),
  60. expansion = 0.5,
  61. shortcut = True,
  62. act_type = cfg.bk_act,
  63. norm_type = cfg.bk_norm,
  64. depthwise = cfg.bk_depthwise)
  65. )
  66. # P5/32
  67. self.layer_5 = nn.Sequential(
  68. BasicConv(self.feat_dims[3], self.feat_dims[4],
  69. kernel_size=3, padding=1, stride=2,
  70. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  71. CSPBlock(in_dim = self.feat_dims[4],
  72. out_dim = self.feat_dims[4],
  73. num_blocks = round(3*cfg.depth),
  74. expansion = 0.5,
  75. shortcut = True,
  76. act_type = cfg.bk_act,
  77. norm_type = cfg.bk_norm,
  78. depthwise = cfg.bk_depthwise)
  79. )
  80. # Initialize all layers
  81. self.init_weights()
  82. def init_weights(self):
  83. """Initialize the parameters."""
  84. for m in self.modules():
  85. if isinstance(m, torch.nn.Conv2d):
  86. # In order to be consistent with the source code,
  87. # reset the Conv2d initialization parameters
  88. m.reset_parameters()
  89. def forward(self, x):
  90. c1 = self.layer_1(x)
  91. c2 = self.layer_2(c1)
  92. c3 = self.layer_3(c2)
  93. c4 = self.layer_4(c3)
  94. c5 = self.layer_5(c4)
  95. outputs = [c3, c4, c5]
  96. return outputs
  97. if __name__ == '__main__':
  98. import time
  99. from thop import profile
  100. class BaseConfig(object):
  101. def __init__(self) -> None:
  102. self.bk_act = 'silu'
  103. self.bk_norm = 'BN'
  104. self.bk_depthwise = False
  105. self.width = 0.5
  106. self.depth = 0.34
  107. self.scale = "s"
  108. self.use_pretrained = True
  109. cfg = BaseConfig()
  110. model = Yolov5Backbone(cfg)
  111. x = torch.randn(1, 3, 640, 640)
  112. t0 = time.time()
  113. outputs = model(x)
  114. t1 = time.time()
  115. print('Time: ', t1 - t0)
  116. for out in outputs:
  117. print(out.shape)
  118. x = torch.randn(1, 3, 640, 640)
  119. print('==============================')
  120. flops, params = profile(model, inputs=(x, ), verbose=False)
  121. print('==============================')
  122. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  123. print('Params : {:.2f} M'.format(params / 1e6))