yolov2_backbone.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import torch
  2. import torch.nn as nn
  # Download locations of pretrained weights, keyed by backbone name.
  # build_backbone() reads this table when called with pretrained=True.
  # NOTE(review): the URL points at a classification-repo release; presumably
  # the checkpoint also carries head weights, which are filtered out on load.
  model_urls = {
      "darknet19": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet19.pth",
  }

  # Public API exported by ``from <module> import *``. Every listed name must be
  # defined in this module, otherwise the star-import raises AttributeError.
  __all__ = ['DarkNet19', 'build_backbone']
  7. # --------------------- Basic Module -----------------------
  class Conv_BN_LeakyReLU(nn.Module):
      """Convolution -> BatchNorm -> LeakyReLU(0.1) unit.

      The three layers live, in that order, inside ``self.convs`` (an
      ``nn.Sequential``). Parameter names are therefore ``convs.0.*`` for the
      convolution and ``convs.1.*`` for the batch norm, which checkpoints rely on.

      Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
          ksize: Convolution kernel size.
          padding: Zero padding added to each spatial side (default 0).
          stride: Convolution stride (default 1).
          dilation: Convolution dilation (default 1).
      """

      def __init__(self, in_channels, out_channels, ksize, padding=0, stride=1, dilation=1):
          super().__init__()
          # Build the layers one by one (convolution first, so parameter
          # initialisation happens in the same order), then wrap them.
          conv_layer = nn.Conv2d(
              in_channels,
              out_channels,
              ksize,
              stride=stride,
              padding=padding,
              dilation=dilation,
          )
          norm_layer = nn.BatchNorm2d(out_channels)
          act_layer = nn.LeakyReLU(negative_slope=0.1, inplace=True)
          self.convs = nn.Sequential(conv_layer, norm_layer, act_layer)

      def forward(self, x):
          """Run ``x`` through convolution, batch norm and activation."""
          out = self.convs(x)
          return out
  18. # --------------------- DarkNet-19 -----------------------
  class DarkNet19(nn.Module):
      """DarkNet-19 feature extractor (no classification head).

      Stages, with the downsampling factor and channel count of their output:

        conv_1                  stride 2,  32 channels
        conv_2                  stride 4,  64 channels
        conv_3                  stride 8,  128 channels
        conv_4                  stride 8,  256 channels (no pooling)
        maxpool_4 -> conv_5     stride 16, 512 channels
        maxpool_5 -> conv_6     stride 32, 1024 channels

      Every 3x3 convolution is padded by 1 and every 1x1 convolution is
      unpadded, so only the 2x2/stride-2 max pools change the spatial size.
      """

      def __init__(self):
          super().__init__()

          def cbl(c_in, c_out, k):
              # padding = k // 2 -> 1 for 3x3 convs, 0 for 1x1 convs.
              return Conv_BN_LeakyReLU(c_in, c_out, k, padding=k // 2)

          def pool():
              return nn.MaxPool2d((2, 2), 2)

          # stride 2, 32 channels
          self.conv_1 = nn.Sequential(cbl(3, 32, 3), pool())
          # stride 4, 64 channels
          self.conv_2 = nn.Sequential(cbl(32, 64, 3), pool())
          # stride 8, 128 channels: 128 -> 64 -> 128 bottleneck, then pool
          self.conv_3 = nn.Sequential(
              cbl(64, 128, 3),
              cbl(128, 64, 1),
              cbl(64, 128, 3),
              pool(),
          )
          # stride 8, 256 channels: same resolution as conv_3's output
          self.conv_4 = nn.Sequential(
              cbl(128, 256, 3),
              cbl(256, 128, 1),
              cbl(128, 256, 3),
          )
          # stride 16, 512 channels
          self.maxpool_4 = pool()
          self.conv_5 = nn.Sequential(
              cbl(256, 512, 3),
              cbl(512, 256, 1),
              cbl(256, 512, 3),
              cbl(512, 256, 1),
              cbl(256, 512, 3),
          )
          # stride 32, 1024 channels
          self.maxpool_5 = pool()
          self.conv_6 = nn.Sequential(
              cbl(512, 1024, 3),
              cbl(1024, 512, 1),
              cbl(512, 1024, 3),
              cbl(1024, 512, 1),
              cbl(512, 1024, 3),
          )

      def forward(self, x):
          """Return the stride-32, 1024-channel feature map for ``x``.

          ``x`` must have 3 channels (conv_1 is built with in_channels=3).
          """
          feat = self.conv_2(self.conv_1(x))
          feat = self.conv_4(self.conv_3(feat))
          feat = self.conv_5(self.maxpool_4(feat))
          return self.conv_6(self.maxpool_5(feat))
  # --------------------- Functions -----------------------
  73. def build_backbone(model_name='darknet19', pretrained=False):
  74. if model_name == 'darknet19':
  75. # model
  76. model = DarkNet19()
  77. feat_dim = 1024
  78. # load weight
  79. if pretrained:
  80. print('Loading pretrained weight ...')
  81. url = model_urls['darknet19']
  82. # checkpoint state dict
  83. checkpoint_state_dict = torch.hub.load_state_dict_from_url(
  84. url=url, map_location="cpu", check_hash=True)
  85. # model state dict
  86. model_state_dict = model.state_dict()
  87. # check
  88. for k in list(checkpoint_state_dict.keys()):
  89. if k in model_state_dict:
  90. shape_model = tuple(model_state_dict[k].shape)
  91. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  92. if shape_model != shape_checkpoint:
  93. checkpoint_state_dict.pop(k)
  94. else:
  95. checkpoint_state_dict.pop(k)
  96. print(k)
  97. model.load_state_dict(checkpoint_state_dict)
  98. return model, feat_dim
  if __name__ == '__main__':
      # Smoke test: build the pretrained backbone and time one forward pass.
      import time

      model, feat_dim = build_backbone(pretrained=True)
      # eval(): BatchNorm uses running statistics; no_grad(): no autograd
      # bookkeeping. Together they make the timing reflect inference cost.
      model.eval()
      x = torch.randn(1, 3, 224, 224)
      with torch.no_grad():
          # perf_counter() is the monotonic, high-resolution clock meant for
          # measuring short intervals (time.time() is wall-clock time).
          t0 = time.perf_counter()
          y = model(x)
          t1 = time.perf_counter()
      print('Time: ', t1 - t0)
      print('Output shape: ', tuple(y.shape), '| feat_dim: ', feat_dim)