# yolov2_backbone.py (3.9 KB)
  1. import torch
  2. import torch.nn as nn
  3. import os
  4. model_urls = {
  5. "darknet19": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet19.pth",
  6. }
  7. __all__ = ['darknet19']
  8. class Conv_BN_LeakyReLU(nn.Module):
  9. def __init__(self, in_channels, out_channels, ksize, padding=0, stride=1, dilation=1):
  10. super(Conv_BN_LeakyReLU, self).__init__()
  11. self.convs = nn.Sequential(
  12. nn.Conv2d(in_channels, out_channels, ksize, padding=padding, stride=stride, dilation=dilation),
  13. nn.BatchNorm2d(out_channels),
  14. nn.LeakyReLU(0.1, inplace=True)
  15. )
  16. def forward(self, x):
  17. return self.convs(x)
  18. class DarkNet19(nn.Module):
  19. def __init__(self):
  20. super(DarkNet19, self).__init__()
  21. # backbone network : DarkNet-19
  22. # output : stride = 2, c = 32
  23. self.conv_1 = nn.Sequential(
  24. Conv_BN_LeakyReLU(3, 32, 3, 1),
  25. nn.MaxPool2d((2,2), 2),
  26. )
  27. # output : stride = 4, c = 64
  28. self.conv_2 = nn.Sequential(
  29. Conv_BN_LeakyReLU(32, 64, 3, 1),
  30. nn.MaxPool2d((2,2), 2)
  31. )
  32. # output : stride = 8, c = 128
  33. self.conv_3 = nn.Sequential(
  34. Conv_BN_LeakyReLU(64, 128, 3, 1),
  35. Conv_BN_LeakyReLU(128, 64, 1),
  36. Conv_BN_LeakyReLU(64, 128, 3, 1),
  37. nn.MaxPool2d((2,2), 2)
  38. )
  39. # output : stride = 8, c = 256
  40. self.conv_4 = nn.Sequential(
  41. Conv_BN_LeakyReLU(128, 256, 3, 1),
  42. Conv_BN_LeakyReLU(256, 128, 1),
  43. Conv_BN_LeakyReLU(128, 256, 3, 1),
  44. )
  45. # output : stride = 16, c = 512
  46. self.maxpool_4 = nn.MaxPool2d((2, 2), 2)
  47. self.conv_5 = nn.Sequential(
  48. Conv_BN_LeakyReLU(256, 512, 3, 1),
  49. Conv_BN_LeakyReLU(512, 256, 1),
  50. Conv_BN_LeakyReLU(256, 512, 3, 1),
  51. Conv_BN_LeakyReLU(512, 256, 1),
  52. Conv_BN_LeakyReLU(256, 512, 3, 1),
  53. )
  54. # output : stride = 32, c = 1024
  55. self.maxpool_5 = nn.MaxPool2d((2, 2), 2)
  56. self.conv_6 = nn.Sequential(
  57. Conv_BN_LeakyReLU(512, 1024, 3, 1),
  58. Conv_BN_LeakyReLU(1024, 512, 1),
  59. Conv_BN_LeakyReLU(512, 1024, 3, 1),
  60. Conv_BN_LeakyReLU(1024, 512, 1),
  61. Conv_BN_LeakyReLU(512, 1024, 3, 1)
  62. )
  63. def forward(self, x):
  64. c1 = self.conv_1(x) # c1
  65. c2 = self.conv_2(c1) # c2
  66. c3 = self.conv_3(c2) # c3
  67. c3 = self.conv_4(c3) # c3
  68. c4 = self.conv_5(self.maxpool_4(c3)) # c4
  69. c5 = self.conv_6(self.maxpool_5(c4)) # c5
  70. return c5
  71. def build_backbone(model_name='darknet19', pretrained=False):
  72. if model_name == 'darknet19':
  73. # model
  74. model = DarkNet19()
  75. feat_dim = 1024
  76. # load weight
  77. if pretrained:
  78. print('Loading pretrained weight ...')
  79. url = model_urls['darknet19']
  80. # checkpoint state dict
  81. checkpoint_state_dict = torch.hub.load_state_dict_from_url(
  82. url=url, map_location="cpu", check_hash=True)
  83. # model state dict
  84. model_state_dict = model.state_dict()
  85. # check
  86. for k in list(checkpoint_state_dict.keys()):
  87. if k in model_state_dict:
  88. shape_model = tuple(model_state_dict[k].shape)
  89. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  90. if shape_model != shape_checkpoint:
  91. checkpoint_state_dict.pop(k)
  92. else:
  93. checkpoint_state_dict.pop(k)
  94. print(k)
  95. model.load_state_dict(checkpoint_state_dict)
  96. return model, feat_dim
  97. if __name__ == '__main__':
  98. import time
  99. model, feat_dim = build_backbone(pretrained=True)
  100. x = torch.randn(1, 3, 224, 224)
  101. t0 = time.time()
  102. y = model(x)
  103. t1 = time.time()
  104. print('Time: ', t1 - t0)