cnn_backbone.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .cnn_basic import Conv, ELANBlock, DownSample
  5. except:
  6. from cnn_basic import Conv, ELANBlock, DownSample
  7. model_urls = {
  8. 'elannet_pico': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_pico.pth",
  9. 'elannet_nano': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_nano.pth",
  10. 'elannet_small': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_small.pth",
  11. 'elannet_medium': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_medium.pth",
  12. 'elannet_large': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_large.pth",
  13. 'elannet_huge': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_huge.pth",
  14. }
  15. # ---------------------------- Backbones ----------------------------
  16. # ELANNet-P5
  17. class ELANNet(nn.Module):
  18. def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
  19. super(ELANNet, self).__init__()
  20. self.feat_dims = [int(512 * width), int(1024 * width), int(1024 * width)]
  21. # P1/2
  22. self.layer_1 = nn.Sequential(
  23. Conv(3, int(64*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  24. Conv(int(64*width), int(64*width), k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  25. )
  26. # P2/4
  27. self.layer_2 = nn.Sequential(
  28. Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  29. ELANBlock(in_dim=int(128*width), out_dim=int(256*width), expand_ratio=0.5, depth=depth,
  30. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  31. )
  32. # P3/8
  33. self.layer_3 = nn.Sequential(
  34. DownSample(in_dim=int(256*width), out_dim=int(256*width), act_type=act_type, norm_type=norm_type),
  35. ELANBlock(in_dim=int(256*width), out_dim=int(512*width), expand_ratio=0.5, depth=depth,
  36. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  37. )
  38. # P4/16
  39. self.layer_4 = nn.Sequential(
  40. DownSample(in_dim=int(512*width), out_dim=int(512*width), act_type=act_type, norm_type=norm_type),
  41. ELANBlock(in_dim=int(512*width), out_dim=int(1024*width), expand_ratio=0.5, depth=depth,
  42. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  43. )
  44. # P5/32
  45. self.layer_5 = nn.Sequential(
  46. DownSample(in_dim=int(1024*width), out_dim=int(1024*width), act_type=act_type, norm_type=norm_type),
  47. ELANBlock(in_dim=int(1024*width), out_dim=int(1024*width), expand_ratio=0.25, depth=depth,
  48. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  49. )
  50. def forward(self, x):
  51. c1 = self.layer_1(x)
  52. c2 = self.layer_2(c1)
  53. c3 = self.layer_3(c2)
  54. c4 = self.layer_4(c3)
  55. c5 = self.layer_5(c4)
  56. outputs = [c3, c4, c5]
  57. return outputs
  58. # ---------------------------- Functions ----------------------------
  59. ## load pretrained weight
  60. def load_weight(model, model_name):
  61. # load weight
  62. print('Loading pretrained weight ...')
  63. url = model_urls[model_name]
  64. if url is not None:
  65. checkpoint = torch.hub.load_state_dict_from_url(
  66. url=url, map_location="cpu", check_hash=True)
  67. # checkpoint state dict
  68. checkpoint_state_dict = checkpoint.pop("model")
  69. # model state dict
  70. model_state_dict = model.state_dict()
  71. # check
  72. for k in list(checkpoint_state_dict.keys()):
  73. if k in model_state_dict:
  74. shape_model = tuple(model_state_dict[k].shape)
  75. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  76. if shape_model != shape_checkpoint:
  77. checkpoint_state_dict.pop(k)
  78. else:
  79. checkpoint_state_dict.pop(k)
  80. print(k)
  81. model.load_state_dict(checkpoint_state_dict)
  82. else:
  83. print('No pretrained for {}'.format(model_name))
  84. return model
  85. ## build ELAN-Net
  86. def build_backbone(cfg, pretrained=False):
  87. # model
  88. backbone = ELANNet(
  89. width=cfg['width'],
  90. depth=cfg['depth'],
  91. act_type=cfg['bk_act'],
  92. norm_type=cfg['bk_norm'],
  93. depthwise=cfg['bk_dpw']
  94. )
  95. # check whether to load imagenet pretrained weight
  96. if pretrained:
  97. if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['bk_dpw']:
  98. backbone = load_weight(backbone, model_name='elannet_pico')
  99. elif cfg['width'] == 0.25 and cfg['depth'] == 0.34:
  100. backbone = load_weight(backbone, model_name='elannet_nano')
  101. elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
  102. backbone = load_weight(backbone, model_name='elannet_small')
  103. elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
  104. backbone = load_weight(backbone, model_name='elannet_medium')
  105. elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
  106. backbone = load_weight(backbone, model_name='elannet_large')
  107. elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
  108. backbone = load_weight(backbone, model_name='elannet_huge')
  109. feat_dims = backbone.feat_dims
  110. return backbone, feat_dims
  111. if __name__ == '__main__':
  112. import time
  113. from thop import profile
  114. cfg = {
  115. 'pretrained': True,
  116. 'bk_act': 'silu',
  117. 'bk_norm': 'BN',
  118. 'bk_dpw': True,
  119. 'width': 0.25,
  120. 'depth': 0.34,
  121. }
  122. model, feats = build_backbone(cfg)
  123. x = torch.randn(1, 3, 640, 640)
  124. t0 = time.time()
  125. outputs = model(x)
  126. t1 = time.time()
  127. print('Time: ', t1 - t0)
  128. for out in outputs:
  129. print(out.shape)
  130. print('==============================')
  131. flops, params = profile(model, inputs=(x, ), verbose=False)
  132. print('==============================')
  133. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  134. print('Params : {:.2f} M'.format(params / 1e6))