# yolov3_backbone.py — DarkNet-53 backbone; construct it with build_backbone().
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov3_basic import Conv, ResBlock
  5. except:
  6. from yolov3_basic import Conv, ResBlock
  7. model_urls = {
  8. "darknet53": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet53_silu.pth",
  9. }
  10. # --------------------- DarkNet-53 -----------------------
  11. class DarkNet53(nn.Module):
  12. def __init__(self, act_type='silu', norm_type='BN'):
  13. super(DarkNet53, self).__init__()
  14. self.feat_dims = [256, 512, 1024]
  15. # P1
  16. self.layer_1 = nn.Sequential(
  17. Conv(3, 32, k=3, p=1, act_type=act_type, norm_type=norm_type),
  18. Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  19. ResBlock(64, 64, nblocks=1, act_type=act_type, norm_type=norm_type)
  20. )
  21. # P2
  22. self.layer_2 = nn.Sequential(
  23. Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  24. ResBlock(128, 128, nblocks=2, act_type=act_type, norm_type=norm_type)
  25. )
  26. # P3
  27. self.layer_3 = nn.Sequential(
  28. Conv(128, 256, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  29. ResBlock(256, 256, nblocks=8, act_type=act_type, norm_type=norm_type)
  30. )
  31. # P4
  32. self.layer_4 = nn.Sequential(
  33. Conv(256, 512, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  34. ResBlock(512, 512, nblocks=8, act_type=act_type, norm_type=norm_type)
  35. )
  36. # P5
  37. self.layer_5 = nn.Sequential(
  38. Conv(512, 1024, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  39. ResBlock(1024, 1024, nblocks=4, act_type=act_type, norm_type=norm_type)
  40. )
  41. def forward(self, x):
  42. c1 = self.layer_1(x)
  43. c2 = self.layer_2(c1)
  44. c3 = self.layer_3(c2)
  45. c4 = self.layer_4(c3)
  46. c5 = self.layer_5(c4)
  47. outputs = [c3, c4, c5]
  48. return outputs
  49. # --------------------- Functions -----------------------
  50. def build_backbone(model_name='darknet53', pretrained=False):
  51. """Constructs a darknet-53 model.
  52. Args:
  53. pretrained (bool): If True, returns a model pre-trained on ImageNet
  54. """
  55. if model_name == 'darknet53':
  56. backbone = DarkNet53(act_type='silu', norm_type='BN')
  57. feat_dims = backbone.feat_dims
  58. if pretrained:
  59. url = model_urls['darknet53']
  60. if url is not None:
  61. print('Loading pretrained weight ...')
  62. checkpoint = torch.hub.load_state_dict_from_url(
  63. url=url, map_location="cpu", check_hash=True)
  64. # checkpoint state dict
  65. checkpoint_state_dict = checkpoint.pop("model")
  66. # model state dict
  67. model_state_dict = backbone.state_dict()
  68. # check
  69. for k in list(checkpoint_state_dict.keys()):
  70. if k in model_state_dict:
  71. shape_model = tuple(model_state_dict[k].shape)
  72. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  73. if shape_model != shape_checkpoint:
  74. checkpoint_state_dict.pop(k)
  75. else:
  76. checkpoint_state_dict.pop(k)
  77. print(k)
  78. backbone.load_state_dict(checkpoint_state_dict)
  79. else:
  80. print('No backbone pretrained: DarkNet53')
  81. return backbone, feat_dims
  82. if __name__ == '__main__':
  83. import time
  84. from thop import profile
  85. model, feats = build_backbone(pretrained=False)
  86. x = torch.randn(1, 3, 224, 224)
  87. t0 = time.time()
  88. outputs = model(x)
  89. t1 = time.time()
  90. print('Time: ', t1 - t0)
  91. for out in outputs:
  92. print(out.shape)
  93. x = torch.randn(1, 3, 224, 224)
  94. print('==============================')
  95. flops, params = profile(model, inputs=(x, ), verbose=False)
  96. print('==============================')
  97. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  98. print('Params : {:.2f} M'.format(params / 1e6))