# yolov2_backbone.py
# DarkNet-19 backbone network (PyTorch).
  1. import torch
  2. import torch.nn as nn
  3. import os
  4. model_urls = {
  5. "darknet19": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet19.pth",
  6. }
  7. __all__ = ['darknet19']
  8. class Conv_BN_LeakyReLU(nn.Module):
  9. def __init__(self, in_channels, out_channels, ksize, padding=0, stride=1, dilation=1):
  10. super(Conv_BN_LeakyReLU, self).__init__()
  11. self.convs = nn.Sequential(
  12. nn.Conv2d(in_channels, out_channels, ksize, padding=padding, stride=stride, dilation=dilation),
  13. nn.BatchNorm2d(out_channels),
  14. nn.LeakyReLU(0.1, inplace=True)
  15. )
  16. def forward(self, x):
  17. return self.convs(x)
  18. class DarkNet19(nn.Module):
  19. def __init__(self):
  20. super(DarkNet19, self).__init__()
  21. # backbone network : DarkNet-19
  22. # output : stride = 2, c = 32
  23. self.conv_1 = nn.Sequential(
  24. Conv_BN_LeakyReLU(3, 32, 3, 1),
  25. nn.MaxPool2d((2,2), 2),
  26. )
  27. # output : stride = 4, c = 64
  28. self.conv_2 = nn.Sequential(
  29. Conv_BN_LeakyReLU(32, 64, 3, 1),
  30. nn.MaxPool2d((2,2), 2)
  31. )
  32. # output : stride = 8, c = 128
  33. self.conv_3 = nn.Sequential(
  34. Conv_BN_LeakyReLU(64, 128, 3, 1),
  35. Conv_BN_LeakyReLU(128, 64, 1),
  36. Conv_BN_LeakyReLU(64, 128, 3, 1),
  37. nn.MaxPool2d((2,2), 2)
  38. )
  39. # output : stride = 8, c = 256
  40. self.conv_4 = nn.Sequential(
  41. Conv_BN_LeakyReLU(128, 256, 3, 1),
  42. Conv_BN_LeakyReLU(256, 128, 1),
  43. Conv_BN_LeakyReLU(128, 256, 3, 1),
  44. )
  45. # output : stride = 16, c = 512
  46. self.maxpool_4 = nn.MaxPool2d((2, 2), 2)
  47. self.conv_5 = nn.Sequential(
  48. Conv_BN_LeakyReLU(256, 512, 3, 1),
  49. Conv_BN_LeakyReLU(512, 256, 1),
  50. Conv_BN_LeakyReLU(256, 512, 3, 1),
  51. Conv_BN_LeakyReLU(512, 256, 1),
  52. Conv_BN_LeakyReLU(256, 512, 3, 1),
  53. )
  54. # output : stride = 32, c = 1024
  55. self.maxpool_5 = nn.MaxPool2d((2, 2), 2)
  56. self.conv_6 = nn.Sequential(
  57. Conv_BN_LeakyReLU(512, 1024, 3, 1),
  58. Conv_BN_LeakyReLU(1024, 512, 1),
  59. Conv_BN_LeakyReLU(512, 1024, 3, 1),
  60. Conv_BN_LeakyReLU(1024, 512, 1),
  61. Conv_BN_LeakyReLU(512, 1024, 3, 1)
  62. )
  63. def forward(self, x):
  64. c1 = self.conv_1(x) # c1
  65. c2 = self.conv_2(c1) # c2
  66. c3 = self.conv_3(c2) # c3
  67. c3 = self.conv_4(c3) # c3
  68. c4 = self.conv_5(self.maxpool_4(c3)) # c4
  69. c5 = self.conv_6(self.maxpool_5(c4)) # c5
  70. return c5
  71. def build_darknet19(pretrained=False):
  72. # model
  73. model = DarkNet19()
  74. feat_dim = 1024
  75. # load weight
  76. if pretrained:
  77. print('Loading pretrained weight ...')
  78. url = model_urls['darknet19']
  79. # checkpoint state dict
  80. checkpoint_state_dict = torch.hub.load_state_dict_from_url(
  81. url=url, map_location="cpu", check_hash=True)
  82. # model state dict
  83. model_state_dict = model.state_dict()
  84. # check
  85. for k in list(checkpoint_state_dict.keys()):
  86. if k in model_state_dict:
  87. shape_model = tuple(model_state_dict[k].shape)
  88. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  89. if shape_model != shape_checkpoint:
  90. checkpoint_state_dict.pop(k)
  91. else:
  92. checkpoint_state_dict.pop(k)
  93. print(k)
  94. model.load_state_dict(checkpoint_state_dict)
  95. return model, feat_dim
  96. if __name__ == '__main__':
  97. import time
  98. model, feat_dim = build_darknet19(pretrained=True)
  99. x = torch.randn(1, 3, 224, 224)
  100. t0 = time.time()
  101. y = model(x)
  102. t1 = time.time()
  103. print('Time: ', t1 - t0)