gelan_backbone.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .gelan_basic import BasicConv, RepGElanLayer, ADown
  5. except:
  6. from gelan_basic import BasicConv, RepGElanLayer, ADown
  7. # IN1K pretrained weight
  8. pretrained_urls = {
  9. 's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/gelan_s.pth",
  10. 'm': None,
  11. 'l': None,
  12. 'x': None,
  13. }
  14. # ---------------------------- Basic functions ----------------------------
  15. class GElanBackbone(nn.Module):
  16. def __init__(self, cfg):
  17. super(GElanBackbone, self).__init__()
  18. # ------------------ Basic setting ------------------
  19. self.model_scale = cfg.scale
  20. self.feat_dims = [cfg.backbone_feats["c1"][-1], # 64
  21. cfg.backbone_feats["c2"][-1], # 128
  22. cfg.backbone_feats["c3"][-1], # 256
  23. cfg.backbone_feats["c4"][-1], # 512
  24. cfg.backbone_feats["c5"][-1], # 512
  25. ]
  26. # ------------------ Network setting ------------------
  27. ## P1/2
  28. self.layer_1 = BasicConv(3, cfg.backbone_feats["c1"][0],
  29. kernel_size=3, padding=1, stride=2,
  30. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
  31. # P2/4
  32. self.layer_2 = nn.Sequential(
  33. BasicConv(cfg.backbone_feats["c1"][0], cfg.backbone_feats["c2"][0],
  34. kernel_size=3, padding=1, stride=2,
  35. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  36. RepGElanLayer(in_dim = cfg.backbone_feats["c2"][0],
  37. inter_dims = cfg.backbone_feats["c2"][1],
  38. out_dim = cfg.backbone_feats["c2"][2],
  39. num_blocks = cfg.backbone_depth,
  40. shortcut = True,
  41. act_type = cfg.bk_act,
  42. norm_type = cfg.bk_norm,
  43. depthwise = cfg.bk_depthwise)
  44. )
  45. # P3/8
  46. self.layer_3 = nn.Sequential(
  47. ADown(cfg.backbone_feats["c2"][2], cfg.backbone_feats["c3"][0],
  48. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  49. RepGElanLayer(in_dim = cfg.backbone_feats["c3"][0],
  50. inter_dims = cfg.backbone_feats["c3"][1],
  51. out_dim = cfg.backbone_feats["c3"][2],
  52. num_blocks = cfg.backbone_depth,
  53. shortcut = True,
  54. act_type = cfg.bk_act,
  55. norm_type = cfg.bk_norm,
  56. depthwise = cfg.bk_depthwise)
  57. )
  58. # P4/16
  59. self.layer_4 = nn.Sequential(
  60. ADown(cfg.backbone_feats["c3"][2], cfg.backbone_feats["c4"][0],
  61. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  62. RepGElanLayer(in_dim = cfg.backbone_feats["c4"][0],
  63. inter_dims = cfg.backbone_feats["c4"][1],
  64. out_dim = cfg.backbone_feats["c4"][2],
  65. num_blocks = cfg.backbone_depth,
  66. shortcut = True,
  67. act_type = cfg.bk_act,
  68. norm_type = cfg.bk_norm,
  69. depthwise = cfg.bk_depthwise)
  70. )
  71. # P5/32
  72. self.layer_5 = nn.Sequential(
  73. ADown(cfg.backbone_feats["c4"][2], cfg.backbone_feats["c5"][0],
  74. act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
  75. RepGElanLayer(in_dim = cfg.backbone_feats["c5"][0],
  76. inter_dims = cfg.backbone_feats["c5"][1],
  77. out_dim = cfg.backbone_feats["c5"][2],
  78. num_blocks = cfg.backbone_depth,
  79. shortcut = True,
  80. act_type = cfg.bk_act,
  81. norm_type = cfg.bk_norm,
  82. depthwise = cfg.bk_depthwise)
  83. )
  84. # Initialize all layers
  85. self.init_weights()
  86. # Load imagenet pretrained weight
  87. if cfg.use_pretrained:
  88. self.load_pretrained()
  89. def init_weights(self):
  90. """Initialize the parameters."""
  91. for m in self.modules():
  92. if isinstance(m, torch.nn.Conv2d):
  93. # In order to be consistent with the source code,
  94. # reset the Conv2d initialization parameters
  95. m.reset_parameters()
  96. def load_pretrained(self):
  97. url = pretrained_urls[self.model_scale]
  98. if url is not None:
  99. print('Loading backbone pretrained weight from : {}'.format(url))
  100. # checkpoint state dict
  101. checkpoint = torch.hub.load_state_dict_from_url(
  102. url=url, map_location="cpu", check_hash=True)
  103. checkpoint_state_dict = checkpoint.pop("model")
  104. # model state dict
  105. model_state_dict = self.state_dict()
  106. # check
  107. for k in list(checkpoint_state_dict.keys()):
  108. if k in model_state_dict:
  109. shape_model = tuple(model_state_dict[k].shape)
  110. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  111. if shape_model != shape_checkpoint:
  112. checkpoint_state_dict.pop(k)
  113. else:
  114. checkpoint_state_dict.pop(k)
  115. print('Unused key: ', k)
  116. # load the weight
  117. self.load_state_dict(checkpoint_state_dict)
  118. else:
  119. print('No pretrained weight for model scale: {}.'.format(self.model_scale))
  120. def forward(self, x):
  121. c1 = self.layer_1(x)
  122. c2 = self.layer_2(c1)
  123. c3 = self.layer_3(c2)
  124. c4 = self.layer_4(c3)
  125. c5 = self.layer_5(c4)
  126. outputs = [c3, c4, c5]
  127. return outputs
  128. # ---------------------------- Functions ----------------------------
  129. ## build Yolo's Backbone
  130. def build_backbone(cfg):
  131. # model
  132. if cfg.backbone == "gelan":
  133. backbone = GElanBackbone(cfg)
  134. else:
  135. raise NotImplementedError("Unknown gelan backbone: {}".format(cfg.backbone))
  136. return backbone
  137. if __name__ == '__main__':
  138. import time
  139. from thop import profile
  140. class BaseConfig(object):
  141. def __init__(self) -> None:
  142. self.backbone = 'gelan'
  143. self.use_pretrained = True
  144. self.bk_act = 'silu'
  145. self.bk_norm = 'BN'
  146. self.bk_depthwise = False
  147. # # Gelan-C scale
  148. # self.backbone_feats = {
  149. # "c1": [64],
  150. # "c2": [128, [128, 64], 256],
  151. # "c3": [256, [256, 128], 512],
  152. # "c4": [512, [512, 256], 512],
  153. # "c5": [512, [512, 256], 512],
  154. # }
  155. # self.scale = "l"
  156. # self.backbone_depth = 1
  157. # Gelan-S scale
  158. self.backbone_feats = {
  159. "c1": [32],
  160. "c2": [64, [64, 32], 64],
  161. "c3": [64, [64, 32], 128],
  162. "c4": [128, [128, 64], 256],
  163. "c5": [256, [256, 128], 256],
  164. }
  165. self.scale = "s"
  166. self.backbone_depth = 3
  167. cfg = BaseConfig()
  168. model = build_backbone(cfg)
  169. x = torch.randn(1, 3, 640, 640)
  170. t0 = time.time()
  171. outputs = model(x)
  172. t1 = time.time()
  173. print('Time: ', t1 - t0)
  174. for out in outputs:
  175. print(out.shape)
  176. print('==============================')
  177. flops, params = profile(model, inputs=(x, ), verbose=False)
  178. print('==============================')
  179. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  180. print('Params : {:.2f} M'.format(params / 1e6))