# yolox2_head.py (4.8 KB)
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolox2_basic import Conv
  5. except:
  6. from yolox2_basic import Conv
  7. class DecoupledHead(nn.Module):
  8. def __init__(self, cfg, in_dim, out_dim, num_classes=80):
  9. super().__init__()
  10. print('==============================')
  11. print('Head: Decoupled Head')
  12. # --------- Basic Parameters ----------
  13. self.in_dim = in_dim
  14. self.num_classes = num_classes
  15. self.num_cls_head=cfg['num_cls_head']
  16. self.num_reg_head=cfg['num_reg_head']
  17. # --------- Network Parameters ----------
  18. ## cls head
  19. cls_feats = []
  20. self.cls_out_dim = out_dim
  21. for i in range(cfg['num_cls_head']):
  22. if i == 0:
  23. cls_feats.append(
  24. Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1,
  25. act_type=cfg['head_act'],
  26. norm_type=cfg['head_norm'],
  27. depthwise=cfg['head_depthwise'])
  28. )
  29. else:
  30. cls_feats.append(
  31. Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1,
  32. act_type=cfg['head_act'],
  33. norm_type=cfg['head_norm'],
  34. depthwise=cfg['head_depthwise'])
  35. )
  36. ## reg head
  37. reg_feats = []
  38. self.reg_out_dim = out_dim
  39. for i in range(cfg['num_reg_head']):
  40. if i == 0:
  41. reg_feats.append(
  42. Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1,
  43. act_type=cfg['head_act'],
  44. norm_type=cfg['head_norm'],
  45. depthwise=cfg['head_depthwise'])
  46. )
  47. else:
  48. reg_feats.append(
  49. Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1,
  50. act_type=cfg['head_act'],
  51. norm_type=cfg['head_norm'],
  52. depthwise=cfg['head_depthwise'])
  53. )
  54. self.cls_feats = nn.Sequential(*cls_feats)
  55. self.reg_feats = nn.Sequential(*reg_feats)
  56. ## Pred
  57. self.obj_pred = nn.Conv2d(self.cls_out_dim, 1, kernel_size=1)
  58. self.cls_pred = nn.Conv2d(self.cls_out_dim, num_classes, kernel_size=1)
  59. self.reg_pred = nn.Conv2d(self.reg_out_dim, 4, kernel_size=1)
  60. def forward(self, x):
  61. """
  62. in_feats: (Tensor) [B, C, H, W]
  63. """
  64. cls_feats = self.cls_feats(x)
  65. reg_feats = self.reg_feats(x)
  66. obj_pred = self.obj_pred(reg_feats)
  67. cls_pred = self.cls_pred(cls_feats)
  68. reg_pred = self.reg_pred(reg_feats)
  69. return obj_pred, cls_pred, reg_pred
  70. # build detection head
  71. def build_head(cfg, in_dim, out_dim, num_classes=80):
  72. if cfg['head'] == 'decoupled_head':
  73. head = DecoupledHead(cfg, in_dim, out_dim, num_classes)
  74. return head
  75. if __name__ == '__main__':
  76. import time
  77. from thop import profile
  78. cfg = {
  79. 'head': 'decoupled_head',
  80. 'num_cls_head': 2,
  81. 'num_reg_head': 2,
  82. 'head_act': 'silu',
  83. 'head_norm': 'BN',
  84. 'head_depthwise': False,
  85. 'reg_max': 16,
  86. }
  87. fpn_dims = [256, 512, 512]
  88. # Head-1
  89. model = build_head(cfg, 256, fpn_dims, num_classes=80)
  90. x = torch.randn(1, 256, 80, 80)
  91. t0 = time.time()
  92. outputs = model(x)
  93. t1 = time.time()
  94. print('Time: ', t1 - t0)
  95. # for out in outputs:
  96. # print(out.shape)
  97. print('==============================')
  98. flops, params = profile(model, inputs=(x, ), verbose=False)
  99. print('==============================')
  100. print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  101. print('Head-1: Params : {:.2f} M'.format(params / 1e6))
  102. # Head-2
  103. model = build_head(cfg, 512, fpn_dims, num_classes=80)
  104. x = torch.randn(1, 512, 40, 40)
  105. t0 = time.time()
  106. outputs = model(x)
  107. t1 = time.time()
  108. print('Time: ', t1 - t0)
  109. # for out in outputs:
  110. # print(out.shape)
  111. print('==============================')
  112. flops, params = profile(model, inputs=(x, ), verbose=False)
  113. print('==============================')
  114. print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  115. print('Head-2: Params : {:.2f} M'.format(params / 1e6))
  116. # Head-3
  117. model = build_head(cfg, 512, fpn_dims, num_classes=80)
  118. x = torch.randn(1, 512, 20, 20)
  119. t0 = time.time()
  120. outputs = model(x)
  121. t1 = time.time()
  122. print('Time: ', t1 - t0)
  123. # for out in outputs:
  124. # print(out.shape)
  125. print('==============================')
  126. flops, params = profile(model, inputs=(x, ), verbose=False)
  127. print('==============================')
  128. print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  129. print('Head-3: Params : {:.2f} M'.format(params / 1e6))