# ctrnet_head.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .ctrnet_basic import Conv
  5. except:
  6. from ctrnet_basic import Conv
  7. def build_det_head(cfg, in_dim, out_dim):
  8. head = SDetHead(in_dim = in_dim,
  9. cls_head_dim = out_dim,
  10. reg_head_dim = out_dim,
  11. num_cls_head = cfg['num_cls_head'],
  12. num_reg_head = cfg['num_reg_head'],
  13. act_type = cfg['head_act'],
  14. norm_type = cfg['head_norm'],
  15. depthwise = cfg['head_depthwise']
  16. )
  17. return head
  18. # ---------------------------- Detection Head ----------------------------
  19. ## Single-level Detection Head
  20. class SDetHead(nn.Module):
  21. def __init__(self,
  22. in_dim :int = 256,
  23. cls_head_dim :int = 256,
  24. reg_head_dim :int = 256,
  25. num_cls_head :int = 2,
  26. num_reg_head :int = 2,
  27. act_type :str = "silu",
  28. norm_type :str = "BN",
  29. depthwise :bool = False):
  30. super().__init__()
  31. # --------- Basic Parameters ----------
  32. self.in_dim = in_dim
  33. self.num_cls_head = num_cls_head
  34. self.num_reg_head = num_reg_head
  35. self.act_type = act_type
  36. self.norm_type = norm_type
  37. self.depthwise = depthwise
  38. # --------- Network Parameters ----------
  39. ## cls head
  40. cls_feats = []
  41. self.cls_head_dim = cls_head_dim
  42. for i in range(num_cls_head):
  43. if i == 0:
  44. cls_feats.append(
  45. Conv(in_dim, self.cls_head_dim, k=3, p=1, s=1,
  46. act_type=act_type,
  47. norm_type=norm_type,
  48. depthwise=depthwise)
  49. )
  50. else:
  51. cls_feats.append(
  52. Conv(self.cls_head_dim, self.cls_head_dim, k=3, p=1, s=1,
  53. act_type=act_type,
  54. norm_type=norm_type,
  55. depthwise=depthwise)
  56. )
  57. ## reg head
  58. reg_feats = []
  59. self.reg_head_dim = reg_head_dim
  60. for i in range(num_reg_head):
  61. if i == 0:
  62. reg_feats.append(
  63. Conv(in_dim, self.reg_head_dim, k=3, p=1, s=1,
  64. act_type=act_type,
  65. norm_type=norm_type,
  66. depthwise=depthwise)
  67. )
  68. else:
  69. reg_feats.append(
  70. Conv(self.reg_head_dim, self.reg_head_dim, k=3, p=1, s=1,
  71. act_type=act_type,
  72. norm_type=norm_type,
  73. depthwise=depthwise)
  74. )
  75. self.cls_feats = nn.Sequential(*cls_feats)
  76. self.reg_feats = nn.Sequential(*reg_feats)
  77. self.init_weights()
  78. def init_weights(self):
  79. """Initialize the parameters."""
  80. for m in self.modules():
  81. if isinstance(m, torch.nn.Conv2d):
  82. # In order to be consistent with the source code,
  83. # reset the Conv2d initialization parameters
  84. m.reset_parameters()
  85. def forward(self, x):
  86. """
  87. in_feats: (Tensor) [B, C, H, W]
  88. """
  89. cls_feats = self.cls_feats(x)
  90. reg_feats = self.reg_feats(x)
  91. outputs = {
  92. "cls_feat": cls_feats,
  93. "reg_feat": reg_feats
  94. }
  95. return outputs
  96. if __name__ == '__main__':
  97. import time
  98. from thop import profile
  99. cfg = {
  100. 'head': 'decoupled_head',
  101. 'num_cls_head': 2,
  102. 'num_reg_head': 2,
  103. 'head_act': 'silu',
  104. 'head_norm': 'BN',
  105. 'head_depthwise': False,
  106. 'reg_max': 16,
  107. }
  108. fpn_dims = [256, 256, 256]
  109. out_dim = 256
  110. # Head-1
  111. model = build_det_head(cfg, fpn_dims, out_dim, num_levels=3)
  112. print(model)
  113. fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
  114. t0 = time.time()
  115. outputs = model(fpn_feats)
  116. t1 = time.time()
  117. print('Time: ', t1 - t0)
  118. # for out in outputs:
  119. # print(out.shape)
  120. print('==============================')
  121. flops, params = profile(model, inputs=(fpn_feats, ), verbose=False)
  122. print('==============================')
  123. print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  124. print('Head-1: Params : {:.2f} M'.format(params / 1e6))