# yolov7_backbone.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov7_basic import Conv, ELANBlock, DownSample
  5. except:
  6. from yolov7_basic import Conv, ELANBlock, DownSample
  7. model_urls = {
  8. "elannet_nano": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_nano.pth",
  9. "elannet_tiny": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_tiny.pth",
  10. "elannet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_large.pth",
  11. "elannet_huge": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_huge.pth",
  12. }
  13. # --------------------- ELANNet -----------------------
  14. # ELANNet-Nano
  15. class ELANNet_Nano(nn.Module):
  16. def __init__(self, act_type='lrelu', norm_type='BN', depthwise=True):
  17. super(ELANNet_Nano, self).__init__()
  18. self.feat_dims = [64, 128, 256]
  19. # P1/2
  20. self.layer_1 = Conv(3, 16, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  21. # P2/4
  22. self.layer_2 = nn.Sequential(
  23. Conv(16, 32, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  24. ELANBlock(in_dim=32, out_dim=32, expand_ratio=0.5, depth=1,
  25. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  26. )
  27. # P3/8
  28. self.layer_3 = nn.Sequential(
  29. nn.MaxPool2d((2, 2), 2),
  30. ELANBlock(in_dim=32, out_dim=64, expand_ratio=0.5, depth=1,
  31. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  32. )
  33. # P4/16
  34. self.layer_4 = nn.Sequential(
  35. nn.MaxPool2d((2, 2), 2),
  36. ELANBlock(in_dim=64, out_dim=128, expand_ratio=0.5, depth=1,
  37. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  38. )
  39. # P5/32
  40. self.layer_5 = nn.Sequential(
  41. nn.MaxPool2d((2, 2), 2),
  42. ELANBlock(in_dim=128, out_dim=256, expand_ratio=0.5, depth=1,
  43. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  44. )
  45. def forward(self, x):
  46. c1 = self.layer_1(x)
  47. c2 = self.layer_2(c1)
  48. c3 = self.layer_3(c2)
  49. c4 = self.layer_4(c3)
  50. c5 = self.layer_5(c4)
  51. outputs = [c3, c4, c5]
  52. return outputs
  53. # ELANNet-Tiny
  54. class ELANNet_Tiny(nn.Module):
  55. """
  56. ELAN-Net of YOLOv7-Tiny.
  57. """
  58. def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
  59. super(ELANNet_Tiny, self).__init__()
  60. self.feat_dims = [128, 256, 512]
  61. # P1/2
  62. self.layer_1 = Conv(3, 32, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  63. # P2/4
  64. self.layer_2 = nn.Sequential(
  65. Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  66. ELANBlock(in_dim=64, out_dim=64, expand_ratio=0.5, depth=1,
  67. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  68. )
  69. # P3/8
  70. self.layer_3 = nn.Sequential(
  71. nn.MaxPool2d((2, 2), 2),
  72. ELANBlock(in_dim=64, out_dim=128, expand_ratio=0.5, depth=1,
  73. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  74. )
  75. # P4/16
  76. self.layer_4 = nn.Sequential(
  77. nn.MaxPool2d((2, 2), 2),
  78. ELANBlock(in_dim=128, out_dim=256, expand_ratio=0.5, depth=1,
  79. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  80. )
  81. # P5/32
  82. self.layer_5 = nn.Sequential(
  83. nn.MaxPool2d((2, 2), 2),
  84. ELANBlock(in_dim=256, out_dim=512, expand_ratio=0.5, depth=1,
  85. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  86. )
  87. def forward(self, x):
  88. c1 = self.layer_1(x)
  89. c2 = self.layer_2(c1)
  90. c3 = self.layer_3(c2)
  91. c4 = self.layer_4(c3)
  92. c5 = self.layer_5(c4)
  93. outputs = [c3, c4, c5]
  94. return outputs
  95. ## ELANNet-Large
  96. class ELANNet_Lagre(nn.Module):
  97. def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
  98. super(ELANNet_Lagre, self).__init__()
  99. self.feat_dims = [512, 1024, 1024]
  100. # P1/2
  101. self.layer_1 = nn.Sequential(
  102. Conv(3, 32, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  103. Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  104. Conv(64, 64, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  105. )
  106. # P2/4
  107. self.layer_2 = nn.Sequential(
  108. Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  109. ELANBlock(in_dim=128, out_dim=256, expand_ratio=0.5, depth=2,
  110. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  111. )
  112. # P3/8
  113. self.layer_3 = nn.Sequential(
  114. DownSample(in_dim=256, out_dim=256, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  115. ELANBlock(in_dim=256, out_dim=512, expand_ratio=0.5, depth=2,
  116. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  117. )
  118. # P4/16
  119. self.layer_4 = nn.Sequential(
  120. DownSample(in_dim=512, out_dim=512, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  121. ELANBlock(in_dim=512, out_dim=1024, expand_ratio=0.5, depth=2,
  122. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  123. )
  124. # P5/32
  125. self.layer_5 = nn.Sequential(
  126. DownSample(in_dim=1024, out_dim=1024, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  127. ELANBlock(in_dim=1024, out_dim=1024, expand_ratio=0.25, depth=2,
  128. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  129. )
  130. def forward(self, x):
  131. c1 = self.layer_1(x)
  132. c2 = self.layer_2(c1)
  133. c3 = self.layer_3(c2)
  134. c4 = self.layer_4(c3)
  135. c5 = self.layer_5(c4)
  136. outputs = [c3, c4, c5]
  137. return outputs
  138. ## ELANNet-Huge
  139. class ELANNet_Huge(nn.Module):
  140. def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
  141. super(ELANNet_Huge, self).__init__()
  142. self.feat_dims = [640, 1280, 1280]
  143. # P1/2
  144. self.layer_1 = nn.Sequential(
  145. Conv(3, 40, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  146. Conv(40, 80, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  147. Conv(80, 80, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  148. )
  149. # P2/4
  150. self.layer_2 = nn.Sequential(
  151. Conv(80, 160, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  152. ELANBlock(in_dim=160, out_dim=320, expand_ratio=0.5, depth=3,
  153. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  154. )
  155. # P3/8
  156. self.layer_3 = nn.Sequential(
  157. DownSample(in_dim=320, out_dim=320, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  158. ELANBlock(in_dim=320, out_dim=640, expand_ratio=0.5, depth=3,
  159. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  160. )
  161. # P4/16
  162. self.layer_4 = nn.Sequential(
  163. DownSample(in_dim=640, out_dim=640, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  164. ELANBlock(in_dim=640, out_dim=1280, expand_ratio=0.5, depth=3,
  165. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  166. )
  167. # P5/32
  168. self.layer_5 = nn.Sequential(
  169. DownSample(in_dim=1280, out_dim=1280, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  170. ELANBlock(in_dim=1280, out_dim=1280, expand_ratio=0.25, depth=3,
  171. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  172. )
  173. def forward(self, x):
  174. c1 = self.layer_1(x)
  175. c2 = self.layer_2(c1)
  176. c3 = self.layer_3(c2)
  177. c4 = self.layer_4(c3)
  178. c5 = self.layer_5(c4)
  179. outputs = [c3, c4, c5]
  180. return outputs
  181. # --------------------- Functions -----------------------
  182. def build_backbone(cfg, pretrained=False):
  183. """Constructs a ELANNet model.
  184. Args:
  185. pretrained (bool): If True, returns a model pre-trained on ImageNet
  186. """
  187. # build backbone
  188. if cfg['backbone'] == 'elannet_huge':
  189. backbone = ELANNet_Huge(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
  190. elif cfg['backbone'] == 'elannet_large':
  191. backbone = ELANNet_Lagre(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
  192. elif cfg['backbone'] == 'elannet_tiny':
  193. backbone = ELANNet_Tiny(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
  194. elif cfg['backbone'] == 'elannet_nano':
  195. backbone = ELANNet_Nano(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
  196. # pyramid feat dims
  197. feat_dims = backbone.feat_dims
  198. # load imagenet pretrained weight
  199. if pretrained:
  200. url = model_urls[cfg['backbone']]
  201. if url is not None:
  202. print('Loading pretrained weight for {}.'.format(cfg['backbone'].upper()))
  203. checkpoint = torch.hub.load_state_dict_from_url(
  204. url=url, map_location="cpu", check_hash=True)
  205. # checkpoint state dict
  206. checkpoint_state_dict = checkpoint.pop("model")
  207. # model state dict
  208. model_state_dict = backbone.state_dict()
  209. # check
  210. for k in list(checkpoint_state_dict.keys()):
  211. if k in model_state_dict:
  212. shape_model = tuple(model_state_dict[k].shape)
  213. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  214. if shape_model != shape_checkpoint:
  215. checkpoint_state_dict.pop(k)
  216. else:
  217. checkpoint_state_dict.pop(k)
  218. print(k)
  219. backbone.load_state_dict(checkpoint_state_dict)
  220. else:
  221. print('No backbone pretrained: ELANNet')
  222. return backbone, feat_dims
  223. if __name__ == '__main__':
  224. import time
  225. from thop import profile
  226. cfg = {
  227. 'pretrained': False,
  228. 'backbone': 'elannet_huge',
  229. 'bk_act': 'silu',
  230. 'bk_norm': 'BN',
  231. 'bk_dpw': False,
  232. 'p6_feat': False,
  233. 'p7_feat': False,
  234. }
  235. model, feats = build_backbone(cfg)
  236. x = torch.randn(1, 3, 224, 224)
  237. t0 = time.time()
  238. outputs = model(x)
  239. t1 = time.time()
  240. print('Time: ', t1 - t0)
  241. for out in outputs:
  242. print(out.shape)
  243. print('==============================')
  244. flops, params = profile(model, inputs=(x, ), verbose=False)
  245. print('==============================')
  246. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  247. print('Params : {:.2f} M'.format(params / 1e6))