rtcdet_head.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .rtcdet_basic import Conv
  5. except:
  6. from rtcdet_basic import Conv
  7. def build_det_head(cfg, in_dims, out_dim, num_levels=3):
  8. head = MDetHead(cfg, in_dims, out_dim, num_levels)
  9. return head
  10. def build_seg_head(cfg, in_dims, out_dim):
  11. return MaskHead()
  12. def build_pose_head(cfg, in_dims, out_dim):
  13. return PoseHead()
  14. # ---------------------------- Detection Head ----------------------------
  15. ## Single-level Detection Head
  16. class SDetHead(nn.Module):
  17. def __init__(self,
  18. in_dim :int = 256,
  19. cls_head_dim :int = 256,
  20. reg_head_dim :int = 256,
  21. num_cls_head :int = 2,
  22. num_reg_head :int = 2,
  23. act_type :str = "silu",
  24. norm_type :str = "BN",
  25. depthwise :bool = False):
  26. super().__init__()
  27. # --------- Basic Parameters ----------
  28. self.in_dim = in_dim
  29. self.num_cls_head = num_cls_head
  30. self.num_reg_head = num_reg_head
  31. self.act_type = act_type
  32. self.norm_type = norm_type
  33. self.depthwise = depthwise
  34. # --------- Network Parameters ----------
  35. ## cls head
  36. cls_feats = []
  37. self.cls_head_dim = cls_head_dim
  38. for i in range(num_cls_head):
  39. if i == 0:
  40. cls_feats.append(
  41. Conv(in_dim, self.cls_head_dim, k=3, p=1, s=1,
  42. act_type=act_type,
  43. norm_type=norm_type,
  44. depthwise=depthwise)
  45. )
  46. else:
  47. cls_feats.append(
  48. Conv(self.cls_head_dim, self.cls_head_dim, k=3, p=1, s=1,
  49. act_type=act_type,
  50. norm_type=norm_type,
  51. depthwise=depthwise)
  52. )
  53. ## reg head
  54. reg_feats = []
  55. self.reg_head_dim = reg_head_dim
  56. for i in range(num_reg_head):
  57. if i == 0:
  58. reg_feats.append(
  59. Conv(in_dim, self.reg_head_dim, k=3, p=1, s=1,
  60. act_type=act_type,
  61. norm_type=norm_type,
  62. depthwise=depthwise)
  63. )
  64. else:
  65. reg_feats.append(
  66. Conv(self.reg_head_dim, self.reg_head_dim, k=3, p=1, s=1,
  67. act_type=act_type,
  68. norm_type=norm_type,
  69. depthwise=depthwise)
  70. )
  71. self.cls_feats = nn.Sequential(*cls_feats)
  72. self.reg_feats = nn.Sequential(*reg_feats)
  73. self.init_weights()
  74. def init_weights(self):
  75. """Initialize the parameters."""
  76. for m in self.modules():
  77. if isinstance(m, torch.nn.Conv2d):
  78. # In order to be consistent with the source code,
  79. # reset the Conv2d initialization parameters
  80. m.reset_parameters()
  81. def forward(self, x):
  82. """
  83. in_feats: (Tensor) [B, C, H, W]
  84. """
  85. cls_feats = self.cls_feats(x)
  86. reg_feats = self.reg_feats(x)
  87. return cls_feats, reg_feats
  88. ## Multi-level Detection Head
  89. class MDetHead(nn.Module):
  90. def __init__(self, cfg, in_dims, out_dim, num_levels=3):
  91. super().__init__()
  92. ## ----------- Network Parameters -----------
  93. self.multi_level_heads = nn.ModuleList(
  94. [SDetHead(in_dim=in_dims[level],
  95. cls_head_dim = out_dim,
  96. reg_head_dim = out_dim,
  97. num_cls_head = cfg['num_cls_head'],
  98. num_reg_head = cfg['num_reg_head'],
  99. act_type = cfg['head_act'],
  100. norm_type = cfg['head_norm'],
  101. depthwise = cfg['head_depthwise'])
  102. for level in range(num_levels)
  103. ])
  104. # --------- Basic Parameters ----------
  105. self.in_dims = in_dims
  106. self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
  107. self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
  108. def forward(self, feats):
  109. """
  110. feats: List[(Tensor)] [[B, C, H, W], ...]
  111. """
  112. cls_feats = []
  113. reg_feats = []
  114. for feat, head in zip(feats, self.multi_level_heads):
  115. # ---------------- Pred ----------------
  116. cls_feat, reg_feat = head(feat)
  117. cls_feats.append(cls_feat)
  118. reg_feats.append(reg_feat)
  119. outputs = {
  120. "cls_feat": cls_feats,
  121. "reg_feat": reg_feats
  122. }
  123. return outputs
  124. # ---------------------------- Segmentation Head ----------------------------
  125. class MaskHead(nn.Module):
  126. def __init__(self, *args, **kwargs) -> None:
  127. super().__init__(*args, **kwargs)
  128. def forward(self, x):
  129. return
  130. # ---------------------------- Human-Pose Head ----------------------------
  131. class PoseHead(nn.Module):
  132. def __init__(self, *args, **kwargs) -> None:
  133. super().__init__(*args, **kwargs)
  134. def forward(self, x):
  135. return
  136. if __name__ == '__main__':
  137. import time
  138. from thop import profile
  139. cfg = {
  140. 'head': 'decoupled_head',
  141. 'num_cls_head': 2,
  142. 'num_reg_head': 2,
  143. 'head_act': 'silu',
  144. 'head_norm': 'BN',
  145. 'head_depthwise': False,
  146. 'reg_max': 16,
  147. }
  148. fpn_dims = [256, 256, 256]
  149. out_dim = 256
  150. # Head-1
  151. model = build_det_head(cfg, fpn_dims, out_dim, num_levels=3)
  152. print(model)
  153. fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
  154. t0 = time.time()
  155. outputs = model(fpn_feats)
  156. t1 = time.time()
  157. print('Time: ', t1 - t0)
  158. # for out in outputs:
  159. # print(out.shape)
  160. print('==============================')
  161. flops, params = profile(model, inputs=(fpn_feats, ), verbose=False)
  162. print('==============================')
  163. print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  164. print('Head-1: Params : {:.2f} M'.format(params / 1e6))