yolox_plus_head.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolox_plus_basic import Conv
  5. except:
  6. from yolox_plus_basic import Conv
  7. class DecoupledHead(nn.Module):
  8. def __init__(self, cfg, in_dim, out_dim, num_classes=80):
  9. super().__init__()
  10. print('==============================')
  11. print('Head: Decoupled Head')
  12. # --------- Basic Parameters ----------
  13. self.in_dim = in_dim
  14. self.num_classes = num_classes
  15. self.num_cls_head=cfg['num_cls_head']
  16. self.num_reg_head=cfg['num_reg_head']
  17. # --------- Network Parameters ----------
  18. ## cls head
  19. cls_feats = []
  20. self.cls_out_dim = out_dim
  21. for i in range(cfg['num_cls_head']):
  22. if i == 0:
  23. cls_feats.append(
  24. Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1,
  25. act_type=cfg['head_act'],
  26. norm_type=cfg['head_norm'],
  27. depthwise=cfg['head_depthwise'])
  28. )
  29. else:
  30. cls_feats.append(
  31. Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1,
  32. act_type=cfg['head_act'],
  33. norm_type=cfg['head_norm'],
  34. depthwise=cfg['head_depthwise'])
  35. )
  36. ## reg head
  37. reg_feats = []
  38. self.reg_out_dim = out_dim
  39. for i in range(cfg['num_reg_head']):
  40. if i == 0:
  41. reg_feats.append(
  42. Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1,
  43. act_type=cfg['head_act'],
  44. norm_type=cfg['head_norm'],
  45. depthwise=cfg['head_depthwise'])
  46. )
  47. else:
  48. reg_feats.append(
  49. Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1,
  50. act_type=cfg['head_act'],
  51. norm_type=cfg['head_norm'],
  52. depthwise=cfg['head_depthwise'])
  53. )
  54. self.cls_feats = nn.Sequential(*cls_feats)
  55. self.reg_feats = nn.Sequential(*reg_feats)
  56. ## Pred
  57. self.cls_pred = nn.Conv2d(self.cls_out_dim, num_classes, kernel_size=1)
  58. self.reg_pred = nn.Conv2d(self.reg_out_dim, 4*cfg['reg_max'], kernel_size=1)
  59. def forward(self, x):
  60. """
  61. in_feats: (Tensor) [B, C, H, W]
  62. """
  63. cls_feats = self.cls_feats(x)
  64. reg_feats = self.reg_feats(x)
  65. cls_pred = self.cls_pred(cls_feats)
  66. reg_pred = self.reg_pred(reg_feats)
  67. return cls_pred, reg_pred
  68. # build detection head
  69. def build_head(cfg, in_dim, out_dim, num_classes=80):
  70. if cfg['head'] == 'decoupled_head':
  71. head = DecoupledHead(cfg, in_dim, out_dim, num_classes)
  72. return head
  73. if __name__ == '__main__':
  74. import time
  75. from thop import profile
  76. cfg = {
  77. 'head': 'decoupled_head',
  78. 'num_cls_head': 2,
  79. 'num_reg_head': 2,
  80. 'head_act': 'silu',
  81. 'head_norm': 'BN',
  82. 'head_depthwise': False,
  83. 'reg_max': 16,
  84. }
  85. fpn_dims = [256, 512, 512]
  86. # Head-1
  87. model = build_head(cfg, 256, fpn_dims, num_classes=80)
  88. x = torch.randn(1, 256, 80, 80)
  89. t0 = time.time()
  90. outputs = model(x)
  91. t1 = time.time()
  92. print('Time: ', t1 - t0)
  93. # for out in outputs:
  94. # print(out.shape)
  95. print('==============================')
  96. flops, params = profile(model, inputs=(x, ), verbose=False)
  97. print('==============================')
  98. print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  99. print('Head-1: Params : {:.2f} M'.format(params / 1e6))
  100. # Head-2
  101. model = build_head(cfg, 512, fpn_dims, num_classes=80)
  102. x = torch.randn(1, 512, 40, 40)
  103. t0 = time.time()
  104. outputs = model(x)
  105. t1 = time.time()
  106. print('Time: ', t1 - t0)
  107. # for out in outputs:
  108. # print(out.shape)
  109. print('==============================')
  110. flops, params = profile(model, inputs=(x, ), verbose=False)
  111. print('==============================')
  112. print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  113. print('Head-2: Params : {:.2f} M'.format(params / 1e6))
  114. # Head-3
  115. model = build_head(cfg, 512, fpn_dims, num_classes=80)
  116. x = torch.randn(1, 512, 20, 20)
  117. t0 = time.time()
  118. outputs = model(x)
  119. t1 = time.time()
  120. print('Time: ', t1 - t0)
  121. # for out in outputs:
  122. # print(out.shape)
  123. print('==============================')
  124. flops, params = profile(model, inputs=(x, ), verbose=False)
  125. print('==============================')
  126. print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  127. print('Head-3: Params : {:.2f} M'.format(params / 1e6))