fpn.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from typing import List
  5. try:
  6. from .basic import BasicConv, RTCBlock
  7. from .transformer import TransformerEncoder
  8. except:
  9. from basic import BasicConv, RTCBlock
  10. from transformer import TransformerEncoder
  11. # Build PaFPN
  def build_fpn(cfg, in_dims, out_dim):
      """Construct the feature-pyramid neck selected by ``cfg['fpn']``.

      Args:
          cfg: config mapping. ``cfg['fpn']`` selects the neck; the ``fpn_*``,
              ``en_*`` and ``pe_temperature`` keys configure it.
          in_dims: backbone channel counts, forwarded unchanged to the neck.
          out_dim: output channel count, forwarded unchanged to the neck.

      Returns:
          A ``HybridEncoder`` when ``cfg['fpn'] == 'hybrid_encoder'``.

      Raises:
          NotImplementedError: for any other ``cfg['fpn']`` value.
          KeyError: if a required config key is missing.
      """
      fpn_name = cfg['fpn']
      if fpn_name != 'hybrid_encoder':
          raise NotImplementedError("Unknown PaFPN: <{}>".format(fpn_name))

      # Keyword arguments are evaluated left to right, so config keys are read
      # in exactly this order.
      encoder_kwargs = dict(
          in_dims=in_dims,
          out_dim=out_dim,
          num_blocks=cfg['fpn_num_blocks'],
          act_type=cfg['fpn_act'],
          norm_type=cfg['fpn_norm'],
          depthwise=cfg['fpn_depthwise'],
          num_heads=cfg['en_num_heads'],
          num_layers=cfg['en_num_layers'],
          ffn_dim=cfg['en_ffn_dim'],
          dropout=cfg['en_dropout'],
          pe_temperature=cfg['pe_temperature'],
          en_act_type=cfg['en_act'],
      )
      return HybridEncoder(**encoder_kwargs)
  29. # ----------------- Feature Pyramid Network -----------------
  30. ## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
  class HybridEncoder(nn.Module):
      """Hybrid encoder: Transformer encoder on the c5 level + convolutional PaFPN.

      Each backbone feature (c3, c4, c5) is projected to ``out_dim`` channels by
      a 1x1 ``BasicConv``. The projected c5 map is refined by a
      ``TransformerEncoder``; the three levels are then fused top-down
      (P5 -> P4 -> P3) and bottom-up (P3 -> P4 -> P5) by ``RTCBlock`` layers
      applied to channel-wise concatenations of two ``out_dim`` maps.

      Args:
          in_dims: channel counts of (c3, c4, c5); must have exactly 3 entries.
          out_dim: channel count of every output level.
          num_blocks: forwarded to every ``RTCBlock`` as ``num_blocks``.
          act_type, norm_type: activation / normalization names for the conv layers.
          depthwise: forwarded to the downsample convs and every ``RTCBlock``.
          num_heads, num_layers, ffn_dim, dropout, pe_temperature: forwarded to
              ``TransformerEncoder``.
          en_act_type: activation name passed to ``TransformerEncoder``.

      Raises:
          ValueError: if ``in_dims`` does not have exactly 3 entries.
      """

      def __init__(self,
                   in_dims: List = [256, 512, 1024],
                   out_dim: int = 256,
                   num_blocks: int = 3,
                   act_type: str = 'silu',
                   norm_type: str = 'BN',
                   depthwise: bool = False,
                   # Transformer's parameters
                   num_heads: int = 8,
                   num_layers: int = 1,
                   ffn_dim: int = 1024,
                   dropout: float = 0.1,
                   pe_temperature: float = 10000.,
                   en_act_type: str = 'gelu',
                   ) -> None:
          super().__init__()
          print('==============================')
          print('FPN: {}'.format("RTC-PaFPN"))

          # ---------------- Basic parameters ----------------
          if len(in_dims) != 3:
              raise ValueError("HybridEncoder expects 3 input dims (c3, c4, c5), "
                               "got {}".format(len(in_dims)))
          # Copy so instances never alias the shared, mutable default list.
          self.in_dims = list(in_dims)
          self.out_dim = out_dim
          self.out_dims = [self.out_dim] * len(self.in_dims)
          self.num_heads = num_heads
          self.num_layers = num_layers
          self.ffn_dim = ffn_dim
          c3, c4, c5 = self.in_dims

          def _downsample():
              # 3x3 / stride 2 / pad 1: presumably H -> ceil(H / 2) if BasicConv
              # wraps a plain conv (TODO confirm); channels stay out_dim.
              return BasicConv(self.out_dim, self.out_dim,
                               kernel_size=3, padding=1, stride=2,
                               act_type=act_type, norm_type=norm_type, depthwise=depthwise)

          def _fusion_block():
              # Fuses the concatenation of two out_dim maps back to out_dim channels.
              return RTCBlock(in_dim=self.out_dim * 2,
                              out_dim=self.out_dim,
                              num_blocks=num_blocks,
                              shortcut=False,
                              act_type=act_type,
                              norm_type=norm_type,
                              depthwise=depthwise,
                              )

          # NOTE: submodules are created in a fixed order (affects RNG draws at
          # construction and parameter order); attribute names are state_dict
          # keys, so the "dowmsample" spelling is kept for checkpoint compatibility.

          # ---------------- Input projs ----------------
          self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
          self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
          self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

          # ---------------- Downsample ----------------
          self.dowmsample_layer_1 = _downsample()
          self.dowmsample_layer_2 = _downsample()

          # ---------------- Transformer Encoder ----------------
          self.transformer_encoder = TransformerEncoder(d_model=self.out_dim,
                                                        num_heads=num_heads,
                                                        num_layers=num_layers,
                                                        ffn_dim=ffn_dim,
                                                        pe_temperature=pe_temperature,
                                                        dropout=dropout,
                                                        act_type=en_act_type,
                                                        )

          # ---------------- Top-down FPN ----------------
          self.top_down_layer_1 = _fusion_block()   # P5 -> P4
          self.top_down_layer_2 = _fusion_block()   # P4 -> P3

          # ---------------- Bottom-up PAN ----------------
          self.bottom_up_layer_1 = _fusion_block()  # P3 -> P4
          self.bottom_up_layer_2 = _fusion_block()  # P4 -> P5

          self.init_weights()

      def init_weights(self):
          """Reset every ``nn.Conv2d`` in this module to PyTorch's default init.

          Done to stay consistent with the reference implementation's Conv2d
          initialization, overriding whatever the submodules set at construction.
          """
          for m in self.modules():
              if isinstance(m, torch.nn.Conv2d):
                  m.reset_parameters()

      def forward(self, features):
          """Fuse (c3, c4, c5) into three ``out_dim``-channel feature maps.

          Args:
              features: sequence of three tensors (c3, c4, c5); dim 1 is treated as
                  channels (``torch.cat(..., dim=1)``) and the last two dims as
                  spatial, so presumably NCHW with c3 the finest level.

          Returns:
              list ``[p3, p4, p5]``, each with ``out_dim`` channels.
          """
          c3, c4, c5 = features

          # -------- Input projs --------
          p5 = self.reduce_layer_1(c5)
          p4 = self.reduce_layer_2(c4)
          p3 = self.reduce_layer_3(c3)

          # -------- Transformer encoder --------
          p5 = self.transformer_encoder(p5)

          # -------- Top-down FPN --------
          # Upsample to the finer level's exact H, W instead of a fixed x2: when an
          # input size is not divisible by the total stride the coarser map is
          # ceil(H / 2), and x2 would overshoot by one pixel and break torch.cat.
          # For exact 2x pyramids this equals nearest interpolation with scale 2.
          p5_up = F.interpolate(p5, size=p4.shape[-2:])
          p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
          p4_up = F.interpolate(p4, size=p3.shape[-2:])
          p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))

          # -------- Bottom-up PAN --------
          p3_ds = self.dowmsample_layer_1(p3)
          p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
          p4_ds = self.dowmsample_layer_2(p4)
          p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))

          return [p3, p4, p5]
  if __name__ == '__main__':
      # Smoke test / profiling entry point; runs only when executed as a script.
      import time
      from thop import profile

      cfg = {
          'fpn': 'hybrid_encoder',
          'fpn_act': 'silu',
          'fpn_norm': 'BN',
          'fpn_depthwise': False,
          'fpn_num_blocks': 3,
          'fpn_expansion': 0.5,
          'en_num_heads': 8,
          'en_num_layers': 1,
          'en_ffn_dim': 1024,
          'en_dropout': 0.0,
          'pe_temperature': 10000.,
          'en_act': 'gelu',
      }
      fpn_dims = [256, 512, 1024]
      out_dim = 256
      pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
                       torch.randn(1, fpn_dims[1], 40, 40),
                       torch.randn(1, fpn_dims[2], 20, 20)]
      model = build_fpn(cfg, fpn_dims, out_dim)

      # Measure inference, not training: eval() switches norm/dropout layers to
      # inference behaviour and no_grad() skips building the autograd graph.
      model.eval()
      with torch.no_grad():
          # perf_counter is the monotonic high-resolution clock meant for intervals.
          t0 = time.perf_counter()
          outputs = model(pyramid_feats)
          t1 = time.perf_counter()
          print('Time: ', t1 - t0)
          for out in outputs:
              print(out.shape)
          print('==============================')
          flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
      print('==============================')
      # thop's first return value is assumed to be MACs; x2 converts to FLOPs.
      print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
      print('Params : {:.2f} M'.format(params / 1e6))