vitdet_decoder.py
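"""
Decoder heads for a ViTDet-style detector: a multi-level detection head built
from stacked BasicConv blocks, plus placeholder segmentation and human-pose heads.
"""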

import torch
import torch.nn as nn

# Relative import when used as a package, absolute import when run as a script.
try:
    from .basic_modules.basic import BasicConv
except ImportError:
    from basic_modules.basic import BasicConv
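# NOTE: BasicConv is defined in basic_modules/basic.py and is not shown here.
# The sketch below is kept commented out so it never shadows the real import;
# it is only an assumption about the interface this file relies on (a Conv2d
# followed by a norm layer and an activation, chosen via norm_type / act_type).
#
# class BasicConv(nn.Module):
#     def __init__(self, in_dim, out_dim, kernel_size=1, padding=0, stride=1,
#                  act_type="silu", norm_type="BN"):
#         super().__init__()
#         self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding,
#                               bias=(norm_type is None))
#         self.norm = nn.BatchNorm2d(out_dim) if norm_type == "BN" else nn.Identity()
#         self.act  = nn.SiLU(inplace=True) if act_type == "silu" else nn.Identity()
#
#     def forward(self, x):
#         return self.act(self.norm(self.conv(x)))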
def build_decoder(cfg, in_dims, num_levels=3):
    if cfg['decoder'] == "det_decoder":
        decoder = MultiDetHead(cfg, in_dims, num_levels)
    elif cfg['decoder'] == "seg_decoder":
        decoder = MaskHead()
    elif cfg['decoder'] == "pos_decoder":
        decoder = PoseHead()
    else:
        raise ValueError("Unknown decoder type: {}".format(cfg['decoder']))

    return decoder
# ---------------------------- Detection Head ----------------------------
## Single-level Detection Head
class SingleDetHead(nn.Module):
    def __init__(self,
                 in_dim       :int = 256,
                 cls_head_dim :int = 256,
                 reg_head_dim :int = 256,
                 num_cls_head :int = 2,
                 num_reg_head :int = 2,
                 act_type     :str = "silu",
                 norm_type    :str = "BN",
                 ):
        super().__init__()
        # --------- Basic Parameters ----------
        self.in_dim = in_dim
        self.num_cls_head = num_cls_head
        self.num_reg_head = num_reg_head
        self.act_type = act_type
        self.norm_type = norm_type

        # --------- Network Parameters ----------
        ## cls head
        cls_feats = []
        self.cls_head_dim = cls_head_dim
        for i in range(num_cls_head):
            if i == 0:
                cls_feats.append(
                    BasicConv(in_dim, self.cls_head_dim,
                              kernel_size=3, padding=1, stride=1,
                              act_type=act_type, norm_type=norm_type)
                )
            else:
                cls_feats.append(
                    BasicConv(self.cls_head_dim, self.cls_head_dim,
                              kernel_size=3, padding=1, stride=1,
                              act_type=act_type, norm_type=norm_type)
                )

        ## reg head
        reg_feats = []
        self.reg_head_dim = reg_head_dim
        for i in range(num_reg_head):
            if i == 0:
                reg_feats.append(
                    BasicConv(in_dim, self.reg_head_dim,
                              kernel_size=3, padding=1, stride=1,
                              act_type=act_type, norm_type=norm_type)
                )
            else:
                reg_feats.append(
                    BasicConv(self.reg_head_dim, self.reg_head_dim,
                              kernel_size=3, padding=1, stride=1,
                              act_type=act_type, norm_type=norm_type)
                )

        self.cls_feats = nn.Sequential(*cls_feats)
        self.reg_feats = nn.Sequential(*reg_feats)

        self.init_weights()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()

    def forward(self, x):
        """
        x: (Tensor) [B, C, H, W]
        """
        cls_feats = self.cls_feats(x)
        reg_feats = self.reg_feats(x)

        return cls_feats, reg_feats
## Multi-level Detection Head
class MultiDetHead(nn.Module):
    def __init__(self, cfg, in_dims, num_levels=3):
        super().__init__()
        ## ----------- Network Parameters -----------
        self.multi_level_heads = nn.ModuleList(
            [SingleDetHead(in_dim       = in_dims[level],
                           cls_head_dim = cfg['hidden_dim'],
                           reg_head_dim = cfg['hidden_dim'],
                           num_cls_head = cfg['de_num_cls_layers'],
                           num_reg_head = cfg['de_num_reg_layers'],
                           act_type     = cfg['de_act'],
                           norm_type    = cfg['de_norm'],
                           )
             for level in range(num_levels)
             ])
        # --------- Basic Parameters ----------
        self.in_dims = in_dims
        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim

    def forward(self, feats):
        """
        feats: List[Tensor], one [B, C, H, W] feature map per level.
        Returns a dict of per-level lists: {"cls_feat": [...], "reg_feat": [...]}.
        """
        cls_feats = []
        reg_feats = []
        for feat, head in zip(feats, self.multi_level_heads):
            # ---------------- Pred ----------------
            cls_feat, reg_feat = head(feat)
            cls_feats.append(cls_feat)
            reg_feats.append(reg_feat)

        outputs = {
            "cls_feat": cls_feats,
            "reg_feat": reg_feats
        }

        return outputs
# ---------------------------- Segmentation Head ----------------------------
class MaskHead(nn.Module):
    """Placeholder segmentation head; not implemented yet."""
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x):
        return


# ---------------------------- Human-Pose Head ----------------------------
class PoseHead(nn.Module):
    """Placeholder human-pose head; not implemented yet."""
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x):
        return
if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'width': 1.0,
        'depth': 1.0,
        # Decoder parameters
        'hidden_dim': 256,
        'decoder': 'det_decoder',
        'de_num_cls_layers': 2,
        'de_num_reg_layers': 2,
        'de_act': 'silu',
        'de_norm': 'BN',
    }
    fpn_dims = [256, 256, 256]
    out_dim = 256

    # Head-1
    model = build_decoder(cfg, fpn_dims, num_levels=3)
    print(model)

    fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80),
                 torch.randn(1, fpn_dims[1], 40, 40),
                 torch.randn(1, fpn_dims[2], 20, 20)]
    t0 = time.time()
    outputs = model(fpn_feats)
    t1 = time.time()
    print('Time: ', t1 - t0)
    # for level_feat in outputs["cls_feat"]:
    #     print(level_feat.shape)

    print('==============================')
    flops, params = profile(model, inputs=(fpn_feats, ), verbose=False)
    print('==============================')
    # thop reports MACs, so multiply by 2 to get FLOPs
    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
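# Sanity-check note (an expectation, not captured output): with the demo config
# above, every 3x3 BasicConv uses stride 1 / padding 1, so spatial sizes are
# preserved and outputs["cls_feat"] / outputs["reg_feat"] should each hold three
# tensors of shape [1, 256, 80, 80], [1, 256, 40, 40], and [1, 256, 20, 20];
# timing and GFLOPs vary by machine.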