# yolov4_backbone.py — CSPDarkNet backbones (CSPDarkNet-53 and CSPDarkNet-Tiny) for YOLOv4.
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov4_basic import Conv, CSPBlock
  5. except:
  6. from yolov4_basic import Conv, CSPBlock
  7. model_urls = {
  8. "cspdarknet_tiny": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet_tiny.pth",
  9. "cspdarknet53": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet53_silu.pth",
  10. }
  11. # --------------------- CSPDarkNet-53 -----------------------
  12. ## CSPDarkNet-53
  13. class CSPDarkNet53(nn.Module):
  14. def __init__(self, act_type='silu', norm_type='BN'):
  15. super(CSPDarkNet53, self).__init__()
  16. self.feat_dims = [256, 512, 1024]
  17. # P1
  18. self.layer_1 = nn.Sequential(
  19. Conv(3, 32, k=3, p=1, act_type=act_type, norm_type=norm_type),
  20. Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  21. CSPBlock(64, 64, expand_ratio=0.5, nblocks=1, shortcut=True, act_type=act_type, norm_type=norm_type)
  22. )
  23. # P2
  24. self.layer_2 = nn.Sequential(
  25. Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  26. CSPBlock(128, 128, expand_ratio=0.5, nblocks=2, shortcut=True, act_type=act_type, norm_type=norm_type)
  27. )
  28. # P3
  29. self.layer_3 = nn.Sequential(
  30. Conv(128, 256, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  31. CSPBlock(256, 256, expand_ratio=0.5, nblocks=8, shortcut=True, act_type=act_type, norm_type=norm_type)
  32. )
  33. # P4
  34. self.layer_4 = nn.Sequential(
  35. Conv(256, 512, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  36. CSPBlock(512, 512, expand_ratio=0.5, nblocks=8, shortcut=True, act_type=act_type, norm_type=norm_type)
  37. )
  38. # P5
  39. self.layer_5 = nn.Sequential(
  40. Conv(512, 1024, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  41. CSPBlock(1024, 1024, expand_ratio=0.5, nblocks=4, shortcut=True, act_type=act_type, norm_type=norm_type)
  42. )
  43. def forward(self, x):
  44. c1 = self.layer_1(x)
  45. c2 = self.layer_2(c1)
  46. c3 = self.layer_3(c2)
  47. c4 = self.layer_4(c3)
  48. c5 = self.layer_5(c4)
  49. outputs = [c3, c4, c5]
  50. return outputs
  51. ## CSPDarkNet-Tiny
  52. class CSPDarkNetTiny(nn.Module):
  53. def __init__(self, act_type='silu', norm_type='BN'):
  54. super(CSPDarkNetTiny, self).__init__()
  55. self.feat_dims = [64, 128, 256]
  56. # stride = 2
  57. self.layer_1 = nn.Sequential(
  58. Conv(3, 16, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  59. CSPBlock(16, 16, expand_ratio=0.5, nblocks=1, shortcut=True, act_type=act_type, norm_type=norm_type)
  60. )
  61. # stride = 4
  62. self.layer_2 = nn.Sequential(
  63. Conv(16, 32, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  64. CSPBlock(32, 32, expand_ratio=0.5, nblocks=1, shortcut=True, act_type=act_type, norm_type=norm_type)
  65. )
  66. # stride = 8
  67. self.layer_3 = nn.Sequential(
  68. Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  69. CSPBlock(64, 64, expand_ratio=0.5, nblocks=3, shortcut=True, act_type=act_type, norm_type=norm_type)
  70. )
  71. # stride = 16
  72. self.layer_4 = nn.Sequential(
  73. Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  74. CSPBlock(128, 128, expand_ratio=0.5, nblocks=3, shortcut=True, act_type=act_type, norm_type=norm_type)
  75. )
  76. # stride = 32
  77. self.layer_5 = nn.Sequential(
  78. Conv(128, 256, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  79. CSPBlock(256, 256, expand_ratio=0.5, nblocks=2, shortcut=True, act_type=act_type, norm_type=norm_type)
  80. )
  81. def forward(self, x):
  82. c1 = self.layer_1(x)
  83. c2 = self.layer_2(c1)
  84. c3 = self.layer_3(c2)
  85. c4 = self.layer_4(c3)
  86. c5 = self.layer_5(c4)
  87. outputs = [c3, c4, c5]
  88. return outputs
  89. # --------------------- Functions -----------------------
  90. def build_backbone(model_name='cspdarknet53', pretrained=False):
  91. """Constructs a cspdarknet-53 model.
  92. Args:
  93. pretrained (bool): If True, returns a model pre-trained on ImageNet
  94. """
  95. if model_name == 'cspdarknet53':
  96. backbone = CSPDarkNet53(act_type='silu', norm_type='BN')
  97. feat_dims = backbone.feat_dims
  98. elif model_name == 'cspdarknet_tiny':
  99. backbone = CSPDarkNetTiny(act_type='silu', norm_type='BN')
  100. feat_dims = backbone.feat_dims
  101. if pretrained:
  102. url = model_urls[model_name]
  103. if url is not None:
  104. print('Loading pretrained weight ...')
  105. checkpoint = torch.hub.load_state_dict_from_url(
  106. url=url, map_location="cpu", check_hash=True)
  107. # checkpoint state dict
  108. checkpoint_state_dict = checkpoint.pop("model")
  109. # model state dict
  110. model_state_dict = backbone.state_dict()
  111. # check
  112. for k in list(checkpoint_state_dict.keys()):
  113. if k in model_state_dict:
  114. shape_model = tuple(model_state_dict[k].shape)
  115. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  116. if shape_model != shape_checkpoint:
  117. checkpoint_state_dict.pop(k)
  118. else:
  119. checkpoint_state_dict.pop(k)
  120. print('Unused key: ', k)
  121. backbone.load_state_dict(checkpoint_state_dict)
  122. else:
  123. print('No backbone pretrained: CSPDarkNet53')
  124. return backbone, feat_dims
  125. if __name__ == '__main__':
  126. import time
  127. from thop import profile
  128. model, feats = build_backbone(model_name='cspdarknet_tiny', pretrained=False)
  129. x = torch.randn(1, 3, 224, 224)
  130. t0 = time.time()
  131. outputs = model(x)
  132. t1 = time.time()
  133. print('Time: ', t1 - t0)
  134. for out in outputs:
  135. print(out.shape)
  136. x = torch.randn(1, 3, 224, 224)
  137. print('==============================')
  138. flops, params = profile(model, inputs=(x, ), verbose=False)
  139. print('==============================')
  140. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  141. print('Params : {:.2f} M'.format(params / 1e6))