# yolox_backbone.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule, CSPBlock
  5. except:
  6. from modules import ConvModule, CSPBlock
  7. # IN1K pretrained weight
  8. pretrained_urls = {
  9. 'n': None,
  10. 's': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/cspdarknet_s_in1k_70.1.pth",
  11. 'm': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/cspdarknet_m_in1k_75.2.pth",
  12. 'l': None,
  13. 'x': None,
  14. }
# --------------------- YOLOX's Backbone -----------------------
## Modified DarkNet
  17. class YoloxBackbone(nn.Module):
  18. def __init__(self, cfg):
  19. super(YoloxBackbone, self).__init__()
  20. # ------------------ Basic setting ------------------
  21. self.model_scale = cfg.model_scale
  22. self.feat_dims = [round(64 * cfg.width),
  23. round(128 * cfg.width),
  24. round(256 * cfg.width),
  25. round(512 * cfg.width),
  26. round(1024 * cfg.width)]
  27. # ------------------ Network setting ------------------
  28. ## P1/2
  29. self.layer_1 = ConvModule(3, self.feat_dims[0], kernel_size=6, padding=2, stride=2)
  30. # P2/4
  31. self.layer_2 = nn.Sequential(
  32. ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, padding=1, stride=2),
  33. CSPBlock(in_dim = self.feat_dims[1],
  34. out_dim = self.feat_dims[1],
  35. num_blocks = round(3*cfg.depth),
  36. expansion = 0.5,
  37. shortcut = True,
  38. )
  39. )
  40. # P3/8
  41. self.layer_3 = nn.Sequential(
  42. ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, padding=1, stride=2),
  43. CSPBlock(in_dim = self.feat_dims[2],
  44. out_dim = self.feat_dims[2],
  45. num_blocks = round(9*cfg.depth),
  46. expansion = 0.5,
  47. shortcut = True,
  48. )
  49. )
  50. # P4/16
  51. self.layer_4 = nn.Sequential(
  52. ConvModule(self.feat_dims[2], self.feat_dims[3], kernel_size=3, padding=1, stride=2),
  53. CSPBlock(in_dim = self.feat_dims[3],
  54. out_dim = self.feat_dims[3],
  55. num_blocks = round(9*cfg.depth),
  56. expansion = 0.5,
  57. shortcut = True,
  58. )
  59. )
  60. # P5/32
  61. self.layer_5 = nn.Sequential(
  62. ConvModule(self.feat_dims[3], self.feat_dims[4], kernel_size=3, padding=1, stride=2),
  63. CSPBlock(in_dim = self.feat_dims[4],
  64. out_dim = self.feat_dims[4],
  65. num_blocks = round(3*cfg.depth),
  66. expansion = 0.5,
  67. shortcut = True,
  68. )
  69. )
  70. # Initialize all layers
  71. self.init_weights()
  72. # Load imagenet pretrained weight
  73. if cfg.use_pretrained:
  74. self.load_pretrained()
  75. def init_weights(self):
  76. """Initialize the parameters."""
  77. for m in self.modules():
  78. if isinstance(m, torch.nn.Conv2d):
  79. m.reset_parameters()
  80. def load_pretrained(self):
  81. url = pretrained_urls[self.model_scale]
  82. if url is not None:
  83. print('Loading backbone pretrained weight from : {}'.format(url))
  84. # checkpoint state dict
  85. checkpoint = torch.hub.load_state_dict_from_url(
  86. url=url, map_location="cpu", check_hash=True)
  87. checkpoint_state_dict = checkpoint.pop("model")
  88. # model state dict
  89. model_state_dict = self.state_dict()
  90. # check
  91. for k in list(checkpoint_state_dict.keys()):
  92. if k in model_state_dict:
  93. shape_model = tuple(model_state_dict[k].shape)
  94. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  95. if shape_model != shape_checkpoint:
  96. checkpoint_state_dict.pop(k)
  97. else:
  98. checkpoint_state_dict.pop(k)
  99. print('Unused key: ', k)
  100. # load the weight
  101. self.load_state_dict(checkpoint_state_dict)
  102. else:
  103. print('No pretrained weight for model scale: {}.'.format(self.model_scale))
  104. def forward(self, x):
  105. c1 = self.layer_1(x)
  106. c2 = self.layer_2(c1)
  107. c3 = self.layer_3(c2)
  108. c4 = self.layer_4(c3)
  109. c5 = self.layer_5(c4)
  110. outputs = [c3, c4, c5]
  111. return outputs
  112. if __name__ == '__main__':
  113. import time
  114. from thop import profile
  115. class BaseConfig(object):
  116. def __init__(self) -> None:
  117. self.width = 0.5
  118. self.depth = 0.34
  119. self.model_scale = "s"
  120. self.use_pretrained = True
  121. cfg = BaseConfig()
  122. model = YoloxBackbone(cfg)
  123. x = torch.randn(1, 3, 640, 640)
  124. t0 = time.time()
  125. outputs = model(x)
  126. t1 = time.time()
  127. print('Time: ', t1 - t0)
  128. for out in outputs:
  129. print(out.shape)
  130. x = torch.randn(1, 3, 640, 640)
  131. print('==============================')
  132. flops, params = profile(model, inputs=(x, ), verbose=False)
  133. print('==============================')
  134. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  135. print('Params : {:.2f} M'.format(params / 1e6))