yolov3_backbone.py

import torch
import torch.nn as nn

try:
    from .yolov3_basic import BasicConv, ResBlock
except ImportError:
    from yolov3_basic import BasicConv, ResBlock

# ImageNet-1K (IN1K) pretrained weights
pretrained_urls = {
    'n': None,
    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/darknet_s_in1k_68.5.pth",
    'm': None,
    'l': None,
    'x': None,
}
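# Only the 's' scale currently ships a released checkpoint; the other scales
# fall back to random initialization (see load_pretrained below).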

# --------------------- Yolov3's Backbone -----------------------
## Modified DarkNet
class Yolov3Backbone(nn.Module):
    def __init__(self, cfg):
        super(Yolov3Backbone, self).__init__()
        # ------------------ Basic setting ------------------
        self.model_scale = cfg.scale
        self.feat_dims = [round(64   * cfg.width),
                          round(128  * cfg.width),
                          round(256  * cfg.width),
                          round(512  * cfg.width),
                          round(1024 * cfg.width)]
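        # e.g. with cfg.width = 0.5 (the "s" scale used in __main__ below),
        # feat_dims = [32, 64, 128, 256, 512].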

        # ------------------ Network setting ------------------
        ## P1/2
        self.layer_1 = BasicConv(3, self.feat_dims[0],
                                 kernel_size=6, padding=2, stride=2,
                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
        ## P2/4
        self.layer_2 = nn.Sequential(
            BasicConv(self.feat_dims[0], self.feat_dims[1],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ResBlock(in_dim     = self.feat_dims[1],
                     out_dim    = self.feat_dims[1],
                     num_blocks = round(3 * cfg.depth),
                     expansion  = 0.5,
                     shortcut   = True,
                     act_type   = cfg.bk_act,
                     norm_type  = cfg.bk_norm,
                     depthwise  = cfg.bk_depthwise)
        )
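        # With cfg.depth = 0.34 (the "s" scale), round(3 * depth) = 1 residual
        # block here, and round(9 * depth) = 3 blocks in the two middle stages.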
        ## P3/8
        self.layer_3 = nn.Sequential(
            BasicConv(self.feat_dims[1], self.feat_dims[2],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ResBlock(in_dim     = self.feat_dims[2],
                     out_dim    = self.feat_dims[2],
                     num_blocks = round(9 * cfg.depth),
                     expansion  = 0.5,
                     shortcut   = True,
                     act_type   = cfg.bk_act,
                     norm_type  = cfg.bk_norm,
                     depthwise  = cfg.bk_depthwise)
        )
        ## P4/16
        self.layer_4 = nn.Sequential(
            BasicConv(self.feat_dims[2], self.feat_dims[3],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ResBlock(in_dim     = self.feat_dims[3],
                     out_dim    = self.feat_dims[3],
                     num_blocks = round(9 * cfg.depth),
                     expansion  = 0.5,
                     shortcut   = True,
                     act_type   = cfg.bk_act,
                     norm_type  = cfg.bk_norm,
                     depthwise  = cfg.bk_depthwise)
        )
        ## P5/32
        self.layer_5 = nn.Sequential(
            BasicConv(self.feat_dims[3], self.feat_dims[4],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            ResBlock(in_dim     = self.feat_dims[4],
                     out_dim    = self.feat_dims[4],
                     num_blocks = round(3 * cfg.depth),
                     expansion  = 0.5,
                     shortcut   = True,
                     act_type   = cfg.bk_act,
                     norm_type  = cfg.bk_norm,
                     depthwise  = cfg.bk_depthwise)
        )
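        # Each stage halves the resolution: c1 is 1/2 scale, c2 is 1/4,
        # c3 is 1/8, c4 is 1/16, and c5 is 1/32 of the input.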

        # Initialize all layers
        self.init_weights()

        # Load the ImageNet pretrained weights
        if cfg.use_pretrained:
            self.load_pretrained()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # To stay consistent with the source code, re-apply the
                # default Conv2d initialization.
                m.reset_parameters()
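
    # NOTE: load_pretrained() drops checkpoint entries that have no matching
    # key in this model, or whose tensor shapes differ, so a scale/width
    # mismatch degrades to a partial load rather than a hard error.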
    def load_pretrained(self):
        url = pretrained_urls[self.model_scale]
        if url is not None:
            print('Loading backbone pretrained weight from : {}'.format(url))
            # checkpoint state dict
            checkpoint = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", check_hash=True)
            checkpoint_state_dict = checkpoint.pop("model")
            # model state dict
            model_state_dict = self.state_dict()
            # Filter out checkpoint keys that do not match the model
            for k in list(checkpoint_state_dict.keys()):
                if k in model_state_dict:
                    shape_model = tuple(model_state_dict[k].shape)
                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                    if shape_model != shape_checkpoint:
                        checkpoint_state_dict.pop(k)
                        print('Shape mismatch, skipping key: ', k)
                else:
                    checkpoint_state_dict.pop(k)
                    print('Unused key: ', k)
            # Load the filtered weights; strict=False tolerates the keys
            # removed above (the default strict=True would raise on any
            # missing key).
            self.load_state_dict(checkpoint_state_dict, strict=False)
        else:
            print('No pretrained weight for model scale: {}.'.format(self.model_scale))

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)
        outputs = [c3, c4, c5]

        return outputs
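
# Quick smoke test: for a 640x640 input with width = 0.5, the three outputs
# should come out as [1, 128, 80, 80], [1, 256, 40, 40], and [1, 512, 20, 20]
# (strides 8, 16, 32).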
if __name__ == '__main__':
    import time
    from thop import profile

    class BaseConfig(object):
        def __init__(self) -> None:
            self.bk_act = 'silu'
            self.bk_norm = 'BN'
            self.bk_depthwise = False
            self.width = 0.5
            self.depth = 0.34
            self.scale = "s"
            self.use_pretrained = True

    cfg = BaseConfig()
    model = Yolov3Backbone(cfg)

    x = torch.randn(1, 3, 640, 640)
    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    x = torch.randn(1, 3, 640, 640)
    # thop reports multiply-accumulates; doubling them approximates FLOPs.
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))