# yolov5_backbone.py -- CSPDarkNet backbone and its builder / pretrained-weight loader.
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov5_basic import Conv, CSPBlock
  5. from .yolov5_neck import SPPF
  6. except:
  7. from yolov5_basic import Conv, CSPBlock
  8. from yolov5_neck import SPPF
  9. # ImageNet-1K pretrained weight
  10. model_urls = {
  11. "cspdarknet_n": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet_nano.pth",
  12. "cspdarknet_s": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet_small.pth",
  13. "cspdarknet_m": None, # For Medium-level, it is not necessary to load pretrained weight.
  14. "cspdarknet_l": None, # For Large-level, it is not necessary to load pretrained weight.
  15. "cspdarknet_x": None, # For Huge-level, it is not necessary to load pretrained weight.
  16. }
  17. # CSPDarkNet
  18. class CSPDarkNet(nn.Module):
  19. def __init__(self, depth=1.0, width=1.0, act_type='silu', norm_type='BN', depthwise=False):
  20. super(CSPDarkNet, self).__init__()
  21. self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(1024 * width)]
  22. # P1/2
  23. self.layer_1 = Conv(3, self.feat_dims[0], k=6, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  24. # P2/4
  25. self.layer_2 = nn.Sequential(
  26. Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  27. CSPBlock(in_dim = self.feat_dims[1],
  28. out_dim = self.feat_dims[1],
  29. expand_ratio = 0.5,
  30. nblocks = round(3*depth),
  31. shortcut = True,
  32. act_type = act_type,
  33. norm_type = norm_type,
  34. depthwise = depthwise)
  35. )
  36. # P3/8
  37. self.layer_3 = nn.Sequential(
  38. Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  39. CSPBlock(in_dim = self.feat_dims[2],
  40. out_dim = self.feat_dims[2],
  41. expand_ratio = 0.5,
  42. nblocks = round(9*depth),
  43. shortcut = True,
  44. act_type = act_type,
  45. norm_type = norm_type,
  46. depthwise = depthwise)
  47. )
  48. # P4/16
  49. self.layer_4 = nn.Sequential(
  50. Conv(self.feat_dims[2], self.feat_dims[3], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  51. CSPBlock(in_dim = self.feat_dims[3],
  52. out_dim = self.feat_dims[3],
  53. expand_ratio = 0.5,
  54. nblocks = round(9*depth),
  55. shortcut = True,
  56. act_type = act_type,
  57. norm_type = norm_type,
  58. depthwise = depthwise)
  59. )
  60. # P5/32
  61. self.layer_5 = nn.Sequential(
  62. Conv(self.feat_dims[3], self.feat_dims[4], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  63. SPPF(self.feat_dims[4], self.feat_dims[4], expand_ratio=0.5),
  64. CSPBlock(in_dim = self.feat_dims[4],
  65. out_dim = self.feat_dims[4],
  66. expand_ratio = 0.5,
  67. nblocks = round(3*depth),
  68. shortcut = True,
  69. act_type = act_type,
  70. norm_type = norm_type,
  71. depthwise = depthwise)
  72. )
  73. def forward(self, x):
  74. c1 = self.layer_1(x)
  75. c2 = self.layer_2(c1)
  76. c3 = self.layer_3(c2)
  77. c4 = self.layer_4(c3)
  78. c5 = self.layer_5(c4)
  79. outputs = [c3, c4, c5]
  80. return outputs
  81. # ---------------------------- Functions ----------------------------
  82. ## load pretrained weight
  83. def load_weight(model, model_name):
  84. # load weight
  85. print('Loading pretrained weight ...')
  86. url = model_urls[model_name]
  87. if url is not None:
  88. checkpoint = torch.hub.load_state_dict_from_url(
  89. url=url, map_location="cpu", check_hash=True)
  90. # checkpoint state dict
  91. checkpoint_state_dict = checkpoint.pop("model")
  92. # model state dict
  93. model_state_dict = model.state_dict()
  94. # check
  95. for k in list(checkpoint_state_dict.keys()):
  96. if k in model_state_dict:
  97. shape_model = tuple(model_state_dict[k].shape)
  98. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  99. if shape_model != shape_checkpoint:
  100. checkpoint_state_dict.pop(k)
  101. else:
  102. checkpoint_state_dict.pop(k)
  103. print('Unused key: ', k)
  104. model.load_state_dict(checkpoint_state_dict)
  105. else:
  106. print('No pretrained for {}'.format(model_name))
  107. return model
  108. ## build CSPDarkNet
  109. def build_backbone(cfg, pretrained=False):
  110. # Build backbone
  111. backbone = CSPDarkNet(cfg['depth'], cfg['width'], cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
  112. feat_dims = backbone.feat_dims[-3:]
  113. # Load pretrained weight
  114. if pretrained:
  115. if cfg['width'] == 0.25 and cfg['depth'] == 0.34:
  116. backbone = load_weight(backbone, model_name='cspdarknet_n')
  117. elif cfg['width'] == 0.375 and cfg['depth'] == 0.34:
  118. backbone = load_weight(backbone, model_name='cspdarknet_t')
  119. elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
  120. backbone = load_weight(backbone, model_name='cspdarknet_s')
  121. elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
  122. backbone = load_weight(backbone, model_name='cspdarknet_m')
  123. elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
  124. backbone = load_weight(backbone, model_name='cspdarknet_l')
  125. elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
  126. backbone = load_weight(backbone, model_name='cspdarknet_x')
  127. return backbone, feat_dims
  128. if __name__ == '__main__':
  129. import time
  130. from thop import profile
  131. cfg = {
  132. 'bk_pretrained': True,
  133. 'bk_act': 'silu',
  134. 'bk_norm': 'BN',
  135. 'bk_dpw': False,
  136. 'p6_feat': False,
  137. 'p7_feat': False,
  138. 'width': 0.50,
  139. 'depth': 0.34,
  140. }
  141. model, feats = build_backbone(cfg, pretrained=cfg['bk_pretrained'])
  142. x = torch.randn(1, 3, 224, 224)
  143. t0 = time.time()
  144. outputs = model(x)
  145. t1 = time.time()
  146. print('Time: ', t1 - t0)
  147. for out in outputs:
  148. print(out.shape)
  149. x = torch.randn(1, 3, 224, 224)
  150. print('==============================')
  151. flops, params = profile(model, inputs=(x, ), verbose=False)
  152. print('==============================')
  153. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  154. print('Params : {:.2f} M'.format(params / 1e6))