# yolov7_backbone.py

import torch
import torch.nn as nn

try:
    from .yolov7_basic import Conv, ELANBlock, DownSample
except ImportError:
    from yolov7_basic import Conv, ELANBlock, DownSample

model_urls = {
    "elannet_tiny": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_tiny.pth",
    "elannet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_large.pth",
    "elannet_huge": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_huge.pth",
}


# --------------------- ELANNet -----------------------
## ELANNet-Tiny
class ELANNet_Tiny(nn.Module):
    """
    ELAN-Net of YOLOv7-Tiny.
    """
    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
        super(ELANNet_Tiny, self).__init__()
        # -------------- Basic parameters --------------
        self.feat_dims = [32, 64, 128, 256, 512]
        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.5]  # Stage-1 -> Stage-4
        self.branch_depths = [1, 1, 1, 1]           # Stage-1 -> Stage-4

        # -------------- Network parameters --------------
        ## P1/2
        self.layer_1 = Conv(3, self.feat_dims[0], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        ## P2/4: Stage-1
        self.layer_2 = nn.Sequential(
            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[1], self.feat_dims[1], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P3/8: Stage-2
        self.layer_3 = nn.Sequential(
            nn.MaxPool2d((2, 2), 2),
            ELANBlock(self.feat_dims[1], self.feat_dims[2], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P4/16: Stage-3
        self.layer_4 = nn.Sequential(
            nn.MaxPool2d((2, 2), 2),
            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P5/32: Stage-4
        self.layer_5 = nn.Sequential(
            nn.MaxPool2d((2, 2), 2),
            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)

        outputs = [c3, c4, c5]

        return outputs
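
# Rough, illustrative shape sketch for ELANNet_Tiny (assuming a 640x640x3 input
# and the P3/8, P4/16, P5/32 stride labels above):
#   c3 ~ [B, 128, 80, 80], c4 ~ [B, 256, 40, 40], c5 ~ [B, 512, 20, 20]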


## ELANNet-Large
class ELANNet_Large(nn.Module):
    """
    ELAN-Net of YOLOv7-Large.
    """
    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
        super(ELANNet_Large, self).__init__()
        # -------------------- Basic parameters --------------------
        self.feat_dims = [32, 64, 128, 256, 512, 1024, 1024]
        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
        self.branch_depths = [2, 2, 2, 2]            # Stage-1 -> Stage-4

        # -------------------- Network parameters --------------------
        ## P1/2
        self.layer_1 = nn.Sequential(
            Conv(3, self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            Conv(self.feat_dims[1], self.feat_dims[1], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P2/4: Stage-1
        self.layer_2 = nn.Sequential(
            Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P3/8: Stage-2
        self.layer_3 = nn.Sequential(
            DownSample(self.feat_dims[3], self.feat_dims[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P4/16: Stage-3
        self.layer_4 = nn.Sequential(
            DownSample(self.feat_dims[4], self.feat_dims[4], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P5/32: Stage-4
        self.layer_5 = nn.Sequential(
            DownSample(self.feat_dims[5], self.feat_dims[5], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)

        outputs = [c3, c4, c5]

        return outputs
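
# Note (informational): ELANNet_Large and ELANNet_Huge below share the same stage
# layout and differ only in feat_dims and branch_depths; for ELANNet_Large the
# returned pyramid [c3, c4, c5] carries [512, 1024, 1024] channels at strides 8/16/32.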


## ELANNet-Huge
class ELANNet_Huge(nn.Module):
    """
    ELAN-Net of YOLOv7-Huge.
    """
    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
        super(ELANNet_Huge, self).__init__()
        # -------------------- Basic parameters --------------------
        self.feat_dims = [40, 80, 160, 320, 640, 1280, 1280]
        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
        self.branch_depths = [3, 3, 3, 3]            # Stage-1 -> Stage-4

        # -------------------- Network parameters --------------------
        ## P1/2
        self.layer_1 = nn.Sequential(
            Conv(3, self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            Conv(self.feat_dims[1], self.feat_dims[1], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P2/4: Stage-1
        self.layer_2 = nn.Sequential(
            Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P3/8: Stage-2
        self.layer_3 = nn.Sequential(
            DownSample(self.feat_dims[3], self.feat_dims[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P4/16: Stage-3
        self.layer_4 = nn.Sequential(
            DownSample(self.feat_dims[4], self.feat_dims[4], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        ## P5/32: Stage-4
        self.layer_5 = nn.Sequential(
            DownSample(self.feat_dims[5], self.feat_dims[5], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)

        outputs = [c3, c4, c5]

        return outputs


# --------------------- Functions -----------------------
## build backbone
def build_backbone(cfg, pretrained=False):
    # build backbone
    if cfg['backbone'] == 'elannet_huge':
        backbone = ELANNet_Huge(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
    elif cfg['backbone'] == 'elannet_large':
        backbone = ELANNet_Large(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
    elif cfg['backbone'] == 'elannet_tiny':
        backbone = ELANNet_Tiny(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])

    # pyramid feat dims
    feat_dims = backbone.feat_dims[-3:]

    # load imagenet pretrained weight
    if pretrained:
        url = model_urls[cfg['backbone']]
        if url is not None:
            print('Loading pretrained weight for {}.'.format(cfg['backbone'].upper()))
            checkpoint = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", check_hash=True)
            # checkpoint state dict (the checkpoint stores its weights under the "model" key)
            checkpoint_state_dict = checkpoint.pop("model")
            # model state dict
            model_state_dict = backbone.state_dict()
            # check: drop checkpoint weights that are absent from the model or whose shapes do not match
            for k in list(checkpoint_state_dict.keys()):
                if k in model_state_dict:
                    shape_model = tuple(model_state_dict[k].shape)
                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                    if shape_model != shape_checkpoint:
                        checkpoint_state_dict.pop(k)
                else:
                    checkpoint_state_dict.pop(k)
                    print('Unused key: ', k)

            backbone.load_state_dict(checkpoint_state_dict)
        else:
            print('No backbone pretrained: ELANNet')

    return backbone, feat_dims
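
# Illustrative usage sketch (mirrors the defaults in the __main__ block below;
# the dims shown are what the code above implies):
#
#   cfg = {'backbone': 'elannet_tiny', 'bk_act': 'silu', 'bk_norm': 'BN', 'bk_dpw': False}
#   backbone, feat_dims = build_backbone(cfg, pretrained=False)
#   # feat_dims -> [128, 256, 512], the channel counts of the [c3, c4, c5] outputs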


if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'pretrained': False,
        'backbone': 'elannet_tiny',
        'bk_act': 'silu',
        'bk_norm': 'BN',
        'bk_dpw': False,
    }
    model, feats = build_backbone(cfg, cfg['pretrained'])
    x = torch.randn(1, 3, 640, 640)

    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))