gelan_backbone.py

import torch
import torch.nn as nn

try:
    from .gelan_basic import BasicConv, RepGElanLayer, ADown
except ImportError:
    from gelan_basic import BasicConv, RepGElanLayer, ADown


# IN1K pretrained weights
pretrained_urls = {
    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/gelan_s_in1k_68.4.pth",
    'c': None,
}


# ---------------------------- Basic functions ----------------------------
class GElanBackbone(nn.Module):
    def __init__(self, cfg):
        super(GElanBackbone, self).__init__()
        # ------------------ Basic setting ------------------
        self.model_scale = cfg.scale
        # Output channels of stages c1~c5 (actual values depend on cfg.backbone_feats)
        self.feat_dims = [cfg.backbone_feats["c1"][-1],
                          cfg.backbone_feats["c2"][-1],
                          cfg.backbone_feats["c3"][-1],
                          cfg.backbone_feats["c4"][-1],
                          cfg.backbone_feats["c5"][-1],
                          ]

        # ------------------ Network setting ------------------
        ## P1/2
        self.layer_1 = BasicConv(3, cfg.backbone_feats["c1"][0],
                                 kernel_size=3, padding=1, stride=2,
                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
        # P2/4
        self.layer_2 = nn.Sequential(
            BasicConv(cfg.backbone_feats["c1"][0], cfg.backbone_feats["c2"][0],
                      kernel_size=3, padding=1, stride=2,
                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            RepGElanLayer(in_dim     = cfg.backbone_feats["c2"][0],
                          inter_dims = cfg.backbone_feats["c2"][1],
                          out_dim    = cfg.backbone_feats["c2"][2],
                          num_blocks = cfg.backbone_depth,
                          shortcut   = True,
                          act_type   = cfg.bk_act,
                          norm_type  = cfg.bk_norm,
                          depthwise  = cfg.bk_depthwise)
        )
        # P3/8
        self.layer_3 = nn.Sequential(
            ADown(cfg.backbone_feats["c2"][2], cfg.backbone_feats["c3"][0],
                  act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            RepGElanLayer(in_dim     = cfg.backbone_feats["c3"][0],
                          inter_dims = cfg.backbone_feats["c3"][1],
                          out_dim    = cfg.backbone_feats["c3"][2],
                          num_blocks = cfg.backbone_depth,
                          shortcut   = True,
                          act_type   = cfg.bk_act,
                          norm_type  = cfg.bk_norm,
                          depthwise  = cfg.bk_depthwise)
        )
        # P4/16
        self.layer_4 = nn.Sequential(
            ADown(cfg.backbone_feats["c3"][2], cfg.backbone_feats["c4"][0],
                  act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            RepGElanLayer(in_dim     = cfg.backbone_feats["c4"][0],
                          inter_dims = cfg.backbone_feats["c4"][1],
                          out_dim    = cfg.backbone_feats["c4"][2],
                          num_blocks = cfg.backbone_depth,
                          shortcut   = True,
                          act_type   = cfg.bk_act,
                          norm_type  = cfg.bk_norm,
                          depthwise  = cfg.bk_depthwise)
        )
        # P5/32
        self.layer_5 = nn.Sequential(
            ADown(cfg.backbone_feats["c4"][2], cfg.backbone_feats["c5"][0],
                  act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
            RepGElanLayer(in_dim     = cfg.backbone_feats["c5"][0],
                          inter_dims = cfg.backbone_feats["c5"][1],
                          out_dim    = cfg.backbone_feats["c5"][2],
                          num_blocks = cfg.backbone_depth,
                          shortcut   = True,
                          act_type   = cfg.bk_act,
                          norm_type  = cfg.bk_norm,
                          depthwise  = cfg.bk_depthwise)
        )

        # Initialize all layers
        self.init_weights()

        # Load ImageNet-pretrained weight
        if cfg.use_pretrained:
            self.load_pretrained()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # Reset Conv2d parameters to PyTorch's default (Kaiming-uniform)
                # initialization, consistent with the original implementation.
                m.reset_parameters()

    def load_pretrained(self):
        url = pretrained_urls[self.model_scale]
        if url is not None:
            print('Loading backbone pretrained weight from: {}'.format(url))
            # checkpoint state dict
            checkpoint = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", check_hash=True)
            checkpoint_state_dict = checkpoint.pop("model")
            # model state dict
            model_state_dict = self.state_dict()
            # Keep only checkpoint keys that exist in the model with matching shapes
            for k in list(checkpoint_state_dict.keys()):
                if k in model_state_dict:
                    shape_model = tuple(model_state_dict[k].shape)
                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                    if shape_model != shape_checkpoint:
                        checkpoint_state_dict.pop(k)
                else:
                    checkpoint_state_dict.pop(k)
                    print('Unused key: ', k)
            # Load the filtered weights (strict=False so keys removed above
            # do not raise a missing-key error)
            self.load_state_dict(checkpoint_state_dict, strict=False)
        else:
            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
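
    # Minimal sketch of loading a locally downloaded checkpoint instead of going
    # through torch.hub. This assumes the file uses the same {"model": state_dict}
    # layout as the released URL weights above; the local path is hypothetical.
    #
    #   ckpt = torch.load("gelan_s_in1k_68.4.pth", map_location="cpu")
    #   missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
    #   print(missing, unexpected)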

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)
        outputs = [c3, c4, c5]

        return outputs
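

# Shape sanity check (sketch): the three returned maps are the P3/8, P4/16 and
# P5/32 features, so a 640x640 input should give spatial sizes 80, 40 and 20.
# The channel numbers below assume the Gelan-S config used in __main__.
#
#   m = GElanBackbone(cfg)                       # cfg as defined in __main__
#   c3, c4, c5 = m(torch.randn(1, 3, 640, 640))
#   # c3: [1, 128, 80, 80], c4: [1, 256, 40, 40], c5: [1, 256, 20, 20]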


# ---------------------------- Functions ----------------------------
## Build the GELAN backbone
def build_backbone(cfg):
    # model
    if cfg.backbone == "gelan":
        backbone = GElanBackbone(cfg)
    else:
        raise NotImplementedError("Unknown gelan backbone: {}".format(cfg.backbone))

    return backbone
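

# Sketch of how a detection neck/FPN might consume this backbone. The neck class
# is hypothetical; only build_backbone, feat_dims and the [c3, c4, c5] outputs
# come from this file.
#
#   backbone = build_backbone(cfg)
#   c3_dim, c4_dim, c5_dim = backbone.feat_dims[2:]   # channels of the returned maps
#   neck = SomeFPN(in_dims=[c3_dim, c4_dim, c5_dim])  # hypothetical neck module
#   pyramid_feats = neck(backbone(x))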


if __name__ == '__main__':
    import time
    from thop import profile

    class BaseConfig(object):
        def __init__(self) -> None:
            self.backbone = 'gelan'
            self.use_pretrained = True
            self.bk_act = 'silu'
            self.bk_norm = 'BN'
            self.bk_depthwise = False
            # # Gelan-C scale
            # self.backbone_feats = {
            #     "c1": [64],
            #     "c2": [128, [128, 64], 256],
            #     "c3": [256, [256, 128], 512],
            #     "c4": [512, [512, 256], 512],
            #     "c5": [512, [512, 256], 512],
            # }
            # self.scale = "c"
            # self.backbone_depth = 1
            # Gelan-S scale
            self.backbone_feats = {
                "c1": [32],
                "c2": [64, [64, 32], 64],
                "c3": [64, [64, 32], 128],
                "c4": [128, [128, 64], 256],
                "c5": [256, [256, 128], 256],
            }
            self.scale = "s"
            self.backbone_depth = 3

    cfg = BaseConfig()
    model = build_backbone(cfg)

    x = torch.randn(1, 3, 640, 640)
    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))