yolox_backbone.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolox_basic import Conv, CSPBlock
  5. from .yolox_neck import SPPF
  6. except:
  7. from yolox_basic import Conv, CSPBlock
  8. from yolox_neck import SPPF
  9. # CSPDarkNet
  10. class CSPDarkNet(nn.Module):
  11. def __init__(self, depth=1.0, width=1.0, act_type='silu', norm_type='BN', depthwise=False):
  12. super(CSPDarkNet, self).__init__()
  13. self.feat_dims = [int(256*width), int(512*width), int(1024*width)]
  14. # P1
  15. self.layer_1 = Conv(3, int(64*width), k=6, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  16. # P2
  17. self.layer_2 = nn.Sequential(
  18. Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  19. CSPBlock(int(128*width), int(128*width), expand_ratio=0.5, nblocks=int(3*depth),
  20. shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  21. )
  22. # P3
  23. self.layer_3 = nn.Sequential(
  24. Conv(int(128*width), int(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  25. CSPBlock(int(256*width), int(256*width), expand_ratio=0.5, nblocks=int(9*depth),
  26. shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  27. )
  28. # P4
  29. self.layer_4 = nn.Sequential(
  30. Conv(int(256*width), int(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  31. CSPBlock(int(512*width), int(512*width), expand_ratio=0.5, nblocks=int(9*depth),
  32. shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  33. )
  34. # P5
  35. self.layer_5 = nn.Sequential(
  36. Conv(int(512*width), int(1024*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  37. SPPF(int(1024*width), int(1024*width), expand_ratio=0.5, act_type=act_type, norm_type=norm_type),
  38. CSPBlock(int(1024*width), int(1024*width), expand_ratio=0.5, nblocks=int(3*depth),
  39. shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  40. )
  41. def forward(self, x):
  42. c1 = self.layer_1(x)
  43. c2 = self.layer_2(c1)
  44. c3 = self.layer_3(c2)
  45. c4 = self.layer_4(c3)
  46. c5 = self.layer_5(c4)
  47. outputs = [c3, c4, c5]
  48. return outputs
  49. # ---------------------------- Functions ----------------------------
  50. def build_backbone(cfg):
  51. """Constructs a darknet-53 model.
  52. Args:
  53. pretrained (bool): If True, returns a model pre-trained on ImageNet
  54. """
  55. backbone = CSPDarkNet(cfg['depth'], cfg['width'], cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
  56. feat_dims = backbone.feat_dims
  57. return backbone, feat_dims
  58. if __name__ == '__main__':
  59. import time
  60. from thop import profile
  61. cfg = {
  62. 'pretrained': False,
  63. 'bk_act': 'lrelu',
  64. 'bk_norm': 'BN',
  65. 'bk_dpw': False,
  66. 'p6_feat': False,
  67. 'p7_feat': False,
  68. 'width': 1.0,
  69. 'depth': 1.0,
  70. }
  71. model, feats = build_backbone(cfg)
  72. x = torch.randn(1, 3, 256, 256)
  73. t0 = time.time()
  74. outputs = model(x)
  75. t1 = time.time()
  76. print('Time: ', t1 - t0)
  77. for out in outputs:
  78. print(out.shape)
  79. x = torch.randn(1, 3, 256, 256)
  80. print('==============================')
  81. flops, params = profile(model, inputs=(x, ), verbose=False)
  82. print('==============================')
  83. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  84. print('Params : {:.2f} M'.format(params / 1e6))