yolov3_backbone.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule, ResBlock
  5. except:
  6. from modules import ConvModule, ResBlock
  7. in1k_pretrained_urls = {
  8. "darknet53": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet53_silu.pth",
  9. }
  10. # --------------------- Yolov3 backbone (DarkNet-53 with SiLU) -----------------------
  11. class Yolov3Backbone(nn.Module):
  12. def __init__(self, use_pretrained: bool = False):
  13. super(Yolov3Backbone, self).__init__()
  14. self.feat_dims = [256, 512, 1024]
  15. self.use_pretrained = use_pretrained
  16. # P1
  17. self.layer_1 = nn.Sequential(
  18. ConvModule(3, 32, kernel_size=3),
  19. ConvModule(32, 64, kernel_size=3, stride=2),
  20. ResBlock(64, 64, num_blocks=1)
  21. )
  22. # P2
  23. self.layer_2 = nn.Sequential(
  24. ConvModule(64, 128, kernel_size=3, stride=2),
  25. ResBlock(128, 128, num_blocks=2)
  26. )
  27. # P3
  28. self.layer_3 = nn.Sequential(
  29. ConvModule(128, 256, kernel_size=3, stride=2),
  30. ResBlock(256, 256, num_blocks=8)
  31. )
  32. # P4
  33. self.layer_4 = nn.Sequential(
  34. ConvModule(256, 512, kernel_size=3, stride=2),
  35. ResBlock(512, 512, num_blocks=8)
  36. )
  37. # P5
  38. self.layer_5 = nn.Sequential(
  39. ConvModule(512, 1024, kernel_size=3, stride=2),
  40. ResBlock(1024, 1024, num_blocks=4)
  41. )
  42. # Initialize all layers
  43. self.init_weights()
  44. def init_weights(self):
  45. """Initialize the parameters."""
  46. for m in self.modules():
  47. if isinstance(m, torch.nn.Conv2d):
  48. m.reset_parameters()
  49. # Load imagenet pretrained weight
  50. if self.use_pretrained:
  51. self.load_pretrained()
  52. def load_pretrained(self):
  53. url = in1k_pretrained_urls["darknet53"]
  54. if url is not None:
  55. print('Loading backbone pretrained weight from : {}'.format(url))
  56. # checkpoint state dict
  57. checkpoint = torch.hub.load_state_dict_from_url(
  58. url=url, map_location="cpu", check_hash=True)
  59. checkpoint_state_dict = checkpoint.pop("model")
  60. # model state dict
  61. model_state_dict = self.state_dict()
  62. # check
  63. for k in list(checkpoint_state_dict.keys()):
  64. if k in model_state_dict:
  65. shape_model = tuple(model_state_dict[k].shape)
  66. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  67. if shape_model != shape_checkpoint:
  68. checkpoint_state_dict.pop(k)
  69. else:
  70. checkpoint_state_dict.pop(k)
  71. print('Unused key: ', k)
  72. # load the weight
  73. self.load_state_dict(checkpoint_state_dict)
  74. else:
  75. print('No pretrained weight for model scale: {}.'.format(self.model_scale))
  76. def forward(self, x):
  77. c1 = self.layer_1(x)
  78. c2 = self.layer_2(c1)
  79. c3 = self.layer_3(c2)
  80. c4 = self.layer_4(c3)
  81. c5 = self.layer_5(c4)
  82. outputs = [c3, c4, c5]
  83. return outputs
  84. if __name__=='__main__':
  85. from thop import profile
  86. # Build backbone
  87. model = Yolov3Backbone(use_pretrained=True)
  88. # Randomly generate a input data
  89. x = torch.randn(2, 3, 640, 640)
  90. # Inference
  91. outputs = model(x)
  92. print(' - the shape of input : ', x.shape)
  93. for out in outputs:
  94. print(' - the shape of output : ', out.shape)
  95. x = torch.randn(1, 3, 640, 640)
  96. flops, params = profile(model, inputs=(x, ), verbose=False)
  97. print('============== FLOPs & Params ================')
  98. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  99. print(' - Params : {:.2f} M'.format(params / 1e6))