rtcdet_backbone.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .rtcdet_basic import Conv, RTCBlock
  5. except:
  6. from rtcdet_basic import Conv, RTCBlock
  7. # Pretrained weights
  8. model_urls = {
  9. # IN-1k classification pretrained
  10. "rtcnet_n": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_nano.pth",
  11. "rtcnet_s": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_small.pth",
  12. "rtcnet_m": None,
  13. "rtcnet_l": None,
  14. "rtcnet_x": None,
  15. # IN-1k MIM pretrained
  16. "mae_rtcnet_n": None,
  17. "mae_rtcnet_s": None,
  18. "mae_rtcnet_m": None,
  19. "mae_rtcnet_l": None,
  20. "mae_rtcnet_x": None,
  21. }
  22. # ---------------------------- Basic functions ----------------------------
  23. ## Real-time Convolutional Backbone
  24. class RTCBackbone(nn.Module):
  25. def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
  26. super(RTCBackbone, self).__init__()
  27. # ---------------- Basic parameters ----------------
  28. self.width_factor = width
  29. self.depth_factor = depth
  30. self.last_stage_factor = ratio
  31. self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(512 * width * ratio)]
  32. # ---------------- Network parameters ----------------
  33. ## P1/2
  34. self.layer_1 = Conv(3, self.feat_dims[0], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type)
  35. ## P2/4
  36. self.layer_2 = nn.Sequential(
  37. Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  38. RTCBlock(in_dim = self.feat_dims[1],
  39. out_dim = self.feat_dims[1],
  40. num_blocks = round(3*depth),
  41. shortcut = True,
  42. act_type = act_type,
  43. norm_type = norm_type,
  44. depthwise = depthwise)
  45. )
  46. ## P3/8
  47. self.layer_3 = nn.Sequential(
  48. Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  49. RTCBlock(in_dim = self.feat_dims[2],
  50. out_dim = self.feat_dims[2],
  51. num_blocks = round(6*depth),
  52. shortcut = True,
  53. act_type = act_type,
  54. norm_type = norm_type,
  55. depthwise = depthwise)
  56. )
  57. ## P4/16
  58. self.layer_4 = nn.Sequential(
  59. Conv(self.feat_dims[2], self.feat_dims[3], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  60. RTCBlock(in_dim = self.feat_dims[3],
  61. out_dim = self.feat_dims[3],
  62. num_blocks = round(6*depth),
  63. shortcut = True,
  64. act_type = act_type,
  65. norm_type = norm_type,
  66. depthwise = depthwise)
  67. )
  68. ## P5/32
  69. self.layer_5 = nn.Sequential(
  70. Conv(self.feat_dims[3], self.feat_dims[4], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
  71. RTCBlock(in_dim = self.feat_dims[4],
  72. out_dim = self.feat_dims[4],
  73. num_blocks = round(3*depth),
  74. shortcut = True,
  75. act_type = act_type,
  76. norm_type = norm_type,
  77. depthwise = depthwise)
  78. )
  79. def forward(self, x):
  80. c1 = self.layer_1(x)
  81. c2 = self.layer_2(c1)
  82. c3 = self.layer_3(c2)
  83. c4 = self.layer_4(c3)
  84. c5 = self.layer_5(c4)
  85. outputs = [c3, c4, c5]
  86. return outputs
  87. # ---------------------------- Functions ----------------------------
  88. ## Build Backbone network
  89. def build_backbone(cfg, pretrained=False):
  90. # build backbone model
  91. backbone = RTCBackbone(width=cfg['width'],
  92. depth=cfg['depth'],
  93. ratio=cfg['ratio'],
  94. act_type=cfg['bk_act'],
  95. norm_type=cfg['bk_norm'],
  96. depthwise=cfg['bk_depthwise']
  97. )
  98. feat_dims = backbone.feat_dims[-3:]
  99. # Model name
  100. width, depth, ratio = cfg['width'], cfg['depth'], cfg['ratio']
  101. model_name = "{}" if not cfg['bk_pretrained_mae'] else "mae_{}"
  102. if width == 0.25 and depth == 0.34 and ratio == 2.0:
  103. model_name = model_name.format("rtcnet_n")
  104. elif width == 0.50 and depth == 0.34 and ratio == 2.0:
  105. model_name = model_name.format("rtcnet_s")
  106. elif width == 0.75 and depth == 0.67 and ratio == 1.5:
  107. model_name = model_name.format("rtcnet_m")
  108. elif width == 1.0 and depth == 1.0 and ratio == 1.0:
  109. model_name = model_name.format("rtcnet_l")
  110. elif width == 1.25 and depth == 1.34 and ratio == 1.0:
  111. model_name = model_name.format("rtcnet_x")
  112. else:
  113. raise NotImplementedError("No such model size : width={}, depth={}, ratio={}. ".format(width, depth, ratio))
  114. # Load pretrained weight
  115. if pretrained:
  116. backbone = load_pretrained_weight(backbone, model_name)
  117. return backbone, feat_dims
  118. ## Load pretrained weight
  119. def load_pretrained_weight(model, model_name):
  120. # Load pretrained weight
  121. url = model_urls[model_name]
  122. if url is not None:
  123. print('Loading pretrained weight ...')
  124. checkpoint = torch.hub.load_state_dict_from_url(
  125. url=url, map_location="cpu", check_hash=True)
  126. # checkpoint state dict
  127. checkpoint_state_dict = checkpoint.pop("model")
  128. # model state dict
  129. model_state_dict = model.state_dict()
  130. # check
  131. for k in list(checkpoint_state_dict.keys()):
  132. if k in model_state_dict:
  133. shape_model = tuple(model_state_dict[k].shape)
  134. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  135. if shape_model != shape_checkpoint:
  136. checkpoint_state_dict.pop(k)
  137. else:
  138. checkpoint_state_dict.pop(k)
  139. print('Unused key: ', k)
  140. # load the weight
  141. model.load_state_dict(checkpoint_state_dict)
  142. else:
  143. print('No backbone pretrained for {}.'.format(model_name))
  144. return model
  145. if __name__ == '__main__':
  146. import time
  147. from thop import profile
  148. cfg = {
  149. 'bk_pretrained': True,
  150. 'bk_pretrained_mae': False,
  151. 'bk_act': 'silu',
  152. 'bk_norm': 'BN',
  153. 'bk_depthwise': False,
  154. 'width': 0.25,
  155. 'depth': 0.34,
  156. 'ratio': 2.0,
  157. }
  158. model, feats = build_backbone(cfg, pretrained=cfg['bk_pretrained'])
  159. x = torch.randn(1, 3, 640, 640)
  160. t0 = time.time()
  161. outputs = model(x)
  162. t1 = time.time()
  163. print('Time: ', t1 - t0)
  164. for out in outputs:
  165. print(out.shape)
  166. x = torch.randn(1, 3, 640, 640)
  167. print('==============================')
  168. flops, params = profile(model, inputs=(x, ), verbose=False)
  169. print('==============================')
  170. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  171. print('Params : {:.2f} M'.format(params / 1e6))