rtcdet_v2_backbone.py

import torch
import torch.nn as nn

try:
    from .rtcdet_v2_basic import Conv, CSPFasterStage, DSBlock
except ImportError:
    from rtcdet_v2_basic import Conv, CSPFasterStage, DSBlock

model_urls = {
    'fasternet_n': None,
    'fasternet_t': None,
    'fasternet_s': None,
    'fasternet_m': None,
    'fasternet_l': None,
    'fasternet_x': None,
}
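# NOTE: every entry above is a placeholder. Until real checkpoint URLs are
# filled in, load_weight() below always falls through to the
# "No pretrained for ..." branch.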

# ---------------------------- Backbones ----------------------------
# Modified FasterNet
class FasterConvNet(nn.Module):
    def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
        super(FasterConvNet, self).__init__()
        # ------------------ Basic parameters ------------------
        ## scale factor
        self.width = width
        self.depth = depth
        ## pyramid feats
        self.base_dims = [64, 128, 256, 512, 1024]
        self.feat_dims = [round(dim * width) for dim in self.base_dims]
        ## block depth
        self.base_blocks = [3, 9, 9, 3]
        self.feat_blocks = [round(nblock * depth) for nblock in self.base_blocks]
        ## nonlinear
        self.act_type = act_type
        self.norm_type = norm_type
        self.depthwise = depthwise

        # ------------------ Network parameters ------------------
        ## P1/2
        self.layer_1 = nn.Sequential(
            Conv(3, self.feat_dims[0], k=6, p=2, s=2, act_type=self.act_type, norm_type=self.norm_type),
            Conv(self.feat_dims[0], self.feat_dims[0], k=3, p=1, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
        )
        ## P2/4
        self.layer_2 = nn.Sequential(
            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
            CSPFasterStage(self.feat_dims[1], self.feat_dims[1], self.feat_blocks[0], 3, True, self.act_type, self.norm_type)
        )
        ## P3/8
        self.layer_3 = nn.Sequential(
            DSBlock(self.feat_dims[1], self.feat_dims[2], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
            CSPFasterStage(self.feat_dims[2], self.feat_dims[2], self.feat_blocks[1], 3, True, self.act_type, self.norm_type)
        )
        ## P4/16
        self.layer_4 = nn.Sequential(
            DSBlock(self.feat_dims[2], self.feat_dims[3], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
            CSPFasterStage(self.feat_dims[3], self.feat_dims[3], self.feat_blocks[2], 3, True, self.act_type, self.norm_type)
        )
        ## P5/32
        self.layer_5 = nn.Sequential(
            DSBlock(self.feat_dims[3], self.feat_dims[4], act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise),
            CSPFasterStage(self.feat_dims[4], self.feat_dims[4], self.feat_blocks[3], 3, True, self.act_type, self.norm_type)
        )

    def forward(self, x):
        c1 = self.layer_1(x)
        c2 = self.layer_2(c1)
        c3 = self.layer_3(c2)
        c4 = self.layer_4(c3)
        c5 = self.layer_5(c4)
        outputs = [c3, c4, c5]

        return outputs
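
# For a 640x640 input, the three returned features correspond to strides 8, 16
# and 32 (the P3/P4/P5 stages above), i.e. spatial sizes of 80x80, 40x40 and
# 20x20, with feat_dims[2], feat_dims[3] and feat_dims[4] channels
# (256, 512 and 1024 at width=1.0).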

# ---------------------------- Functions ----------------------------
## load pretrained weight
def load_weight(model, model_name):
    # load weight
    print('Loading pretrained weight ...')
    url = model_urls[model_name]
    if url is not None:
        checkpoint = torch.hub.load_state_dict_from_url(
            url=url, map_location="cpu", check_hash=True)
        # checkpoint state dict
        checkpoint_state_dict = checkpoint.pop("model")
        # model state dict
        model_state_dict = model.state_dict()
        # check: drop keys that are missing from the model or whose shapes do not match
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    checkpoint_state_dict.pop(k)
            else:
                checkpoint_state_dict.pop(k)
                print(k)

        # strict=False because some keys may have been filtered out above
        model.load_state_dict(checkpoint_state_dict, strict=False)
    else:
        print('No pretrained for {}'.format(model_name))

    return model
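
## Every URL in model_urls is currently None, so load_weight() cannot fetch
## anything yet. Below is a minimal sketch of loading a local ImageNet
## checkpoint instead; the checkpoint path and the assumption that the file
## stores its weights under a 'model' key are hypothetical, not part of this repo.
def load_local_weight(model, ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # fall back to the raw dict if there is no 'model' key
    checkpoint_state_dict = checkpoint.get('model', checkpoint)
    model_state_dict = model.state_dict()
    # keep only keys that exist in the model with matching shapes
    filtered = {k: v for k, v in checkpoint_state_dict.items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape}
    model.load_state_dict(filtered, strict=False)
    return model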

## build MCNet
def build_backbone(cfg, pretrained=False):
    # model
    backbone = FasterConvNet(cfg['width'], cfg['depth'], cfg['bk_act'], cfg['bk_norm'], cfg['bk_depthwise'])

    # check whether to load imagenet pretrained weight
    if pretrained:
        if cfg['width'] == 0.25 and cfg['depth'] == 0.34:
            backbone = load_weight(backbone, model_name='fasternet_n')
        elif cfg['width'] == 0.375 and cfg['depth'] == 0.34:
            backbone = load_weight(backbone, model_name='fasternet_t')
        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
            backbone = load_weight(backbone, model_name='fasternet_s')
        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
            backbone = load_weight(backbone, model_name='fasternet_m')
        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
            backbone = load_weight(backbone, model_name='fasternet_l')
        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
            backbone = load_weight(backbone, model_name='fasternet_x')
    feat_dims = backbone.feat_dims[-3:]

    return backbone, feat_dims

if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        ## Backbone
        'backbone': 'mcnet',
        'pretrained': True,
        'bk_act': 'silu',
        'bk_norm': 'BN',
        'bk_depthwise': False,
        'width': 1.0,
        'depth': 1.0,
        'stride': [8, 16, 32],  # P3, P4, P5
        'max_stride': 32,
    }
    model, feats = build_backbone(cfg, cfg['pretrained'])
    x = torch.randn(1, 3, 640, 640)

    t0 = time.time()
    outputs = model(x)
    t1 = time.time()
    print('Time: ', t1 - t0)
    for out in outputs:
        print(out.shape)

    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))