# yolov3_backbone.py (5.7 KB)

  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov3_basic import Conv, ResBlock
  5. except:
  6. from yolov3_basic import Conv, ResBlock
  # Download location of the pretrained checkpoint for every backbone name that
  # build_backbone() accepts (the darknet53 file is the SiLU variant, per its name).
  model_urls = dict(
      darknet_tiny="https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet_tiny.pth",
      darknet53="https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet53_silu.pth",
  )
  11. # --------------------- DarkNet-53 -----------------------
  12. ## DarkNet-53
  class DarkNet53(nn.Module):
      """DarkNet-53 feature extractor.

      Built from five stages ``layer_1`` .. ``layer_5`` made of ``Conv`` and
      ``ResBlock`` modules (from ``yolov3_basic``). ``forward`` returns the
      outputs of the last three stages; their channel counts are stored in
      ``feat_dims``.

      Args:
          act_type: activation name forwarded to every Conv / ResBlock.
          norm_type: normalization name forwarded to every Conv / ResBlock.
      """

      def __init__(self, act_type='silu', norm_type='BN'):
          super().__init__()
          # channel widths of the three returned feature maps (c3, c4, c5)
          self.feat_dims = [256, 512, 1024]
          opts = dict(act_type=act_type, norm_type=norm_type)

          def down_stage(c_in, c_out, depth):
              # 3x3 conv with s=2 followed by `depth` residual units; keeping the
              # (Conv, ResBlock) Sequential layout keeps state_dict keys unchanged
              return nn.Sequential(
                  Conv(c_in, c_out, k=3, p=1, s=2, **opts),
                  ResBlock(c_out, c_out, nblocks=depth, **opts),
              )

          # P1: stem conv (default s), then an s=2 conv and one residual unit
          self.layer_1 = nn.Sequential(
              Conv(3, 32, k=3, p=1, **opts),
              Conv(32, 64, k=3, p=1, s=2, **opts),
              ResBlock(64, 64, nblocks=1, **opts),
          )
          self.layer_2 = down_stage(64, 128, 2)    # P2
          self.layer_3 = down_stage(128, 256, 8)   # P3
          self.layer_4 = down_stage(256, 512, 8)   # P4
          self.layer_5 = down_stage(512, 1024, 4)  # P5

      def forward(self, x):
          """Run the five stages in sequence and return ``[c3, c4, c5]``."""
          stage_outputs = []
          for stage in (self.layer_1, self.layer_2, self.layer_3,
                        self.layer_4, self.layer_5):
              x = stage(x)
              stage_outputs.append(x)
          # only the last three stage outputs are exposed as features
          return stage_outputs[2:]
  51. ## DarkNet-Tiny
  class DarkNetTiny(nn.Module):
      """Lightweight DarkNet variant made of five downsampling stages.

      Each stage ``layer_i`` is ``Sequential(Conv(..., s=2), ResBlock(...))``.
      The stages correspond to strides 2, 4, 8, 16 and 32. ``forward`` returns
      the last three stage outputs, whose channel counts are in ``feat_dims``.

      Args:
          act_type: activation name forwarded to every Conv / ResBlock.
          norm_type: normalization name forwarded to every Conv / ResBlock.
      """

      def __init__(self, act_type='silu', norm_type='BN'):
          super().__init__()
          self.feat_dims = [64, 128, 256]
          # (in_channels, out_channels, residual blocks) for layer_1 .. layer_5
          stage_specs = (
              (3, 16, 1),     # stride 2
              (16, 32, 1),    # stride 4
              (32, 64, 3),    # stride 8
              (64, 128, 3),   # stride 16
              (128, 256, 2),  # stride 32
          )
          for idx, (c_in, c_out, depth) in enumerate(stage_specs, start=1):
              stage = nn.Sequential(
                  Conv(c_in, c_out, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
                  ResBlock(c_out, c_out, nblocks=depth, act_type=act_type, norm_type=norm_type),
              )
              # nn.Module.__setattr__ registers the submodule just like
              # `self.layer_<idx> = stage`, in the same order
              setattr(self, 'layer_{}'.format(idx), stage)

      def forward(self, x):
          """Return the outputs ``[c3, c4, c5]`` of stages 3, 4 and 5."""
          c3 = self.layer_3(self.layer_2(self.layer_1(x)))
          c4 = self.layer_4(c3)
          c5 = self.layer_5(c4)
          return [c3, c4, c5]
  89. # --------------------- Functions -----------------------
  def build_backbone(model_name='darknet53', pretrained=False):
      """Construct a DarkNet backbone, optionally loading pretrained weights.

      Args:
          model_name (str): ``'darknet53'`` or ``'darknet_tiny'``.
          pretrained (bool): if True, download the checkpoint listed in
              ``model_urls`` (pre-trained on ImageNet) and load every tensor
              whose name and shape match the model. The remaining parameters
              keep their initial values.

      Returns:
          tuple: ``(backbone, feat_dims)``, where ``feat_dims`` is the list of
          channel counts of the three feature maps the backbone returns.

      Raises:
          ValueError: if ``model_name`` is not a supported backbone.
      """
      if model_name == 'darknet53':
          backbone = DarkNet53(act_type='silu', norm_type='BN')
      elif model_name == 'darknet_tiny':
          backbone = DarkNetTiny(act_type='silu', norm_type='BN')
      else:
          # Previously this fell through and crashed with UnboundLocalError.
          raise ValueError(
              "Unknown backbone '{}': expected 'darknet53' or 'darknet_tiny'".format(model_name))
      feat_dims = backbone.feat_dims

      if pretrained:
          url = model_urls.get(model_name)
          if url is not None:
              _load_pretrained(backbone, url)
          else:
              print('No backbone pretrained: {}'.format(model_name))

      return backbone, feat_dims


  def _load_pretrained(backbone, url):
      """Load the checkpoint at ``url`` into ``backbone``, skipping incompatible tensors."""
      print('Loading pretrained weight ...')
      checkpoint = torch.hub.load_state_dict_from_url(
          url=url, map_location="cpu", check_hash=True)
      # the weights live under the "model" entry of the checkpoint dict
      checkpoint_state_dict = checkpoint.pop("model")
      model_state_dict = backbone.state_dict()
      for k in list(checkpoint_state_dict.keys()):
          if k not in model_state_dict:
              checkpoint_state_dict.pop(k)
              print('Unused key: ', k)
          elif tuple(model_state_dict[k].shape) != tuple(checkpoint_state_dict[k].shape):
              checkpoint_state_dict.pop(k)
              print('Shape mismatch, skipped key: ', k)
      # strict=False: with the default strict=True, every key dropped above would
      # be reported as missing and load_state_dict would raise, defeating the filter.
      backbone.load_state_dict(checkpoint_state_dict, strict=False)
  if __name__ == '__main__':
      # Quick smoke test / benchmark: output shapes, latency, MACs and params.
      import time
      from thop import profile

      model, feats = build_backbone(model_name='darknet53', pretrained=True)
      # eval(): normalization layers use running statistics (inference behavior)
      model.eval()
      x = torch.randn(1, 3, 640, 640)
      # no_grad(): time pure inference without building an autograd graph;
      # perf_counter() is a monotonic, high-resolution interval clock
      with torch.no_grad():
          t0 = time.perf_counter()
          outputs = model(x)
          t1 = time.perf_counter()
      print('Time: ', t1 - t0)
      for out in outputs:
          print(out.shape)
      print('==============================')
      # thop reports multiply-accumulate counts (MACs); x2 approximates FLOPs
      macs, params = profile(model, inputs=(x, ), verbose=False)
      print('==============================')
      print('GFLOPs : {:.2f}'.format(macs / 1e9 * 2))
      print('Params : {:.2f} M'.format(params / 1e6))