fpn.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from typing import List
  try:
      # Package-relative import: works when this file is imported as part of its package.
      from .basic import get_clones, BasicConv, RTCBlock, TransformerEncoder
  except ImportError:
      # Fallback for running this file directly as a script, where the relative
      # import fails with ImportError. Only ImportError is caught so that
      # unrelated failures (and KeyboardInterrupt/SystemExit) are not masked.
      from basic import get_clones, BasicConv, RTCBlock, TransformerEncoder
  9. # Build PaFPN
  def build_fpn(cfg, in_dims, out_dim):
      """Build the feature-pyramid module selected by ``cfg['fpn']``.

      Args:
          cfg: Mapping with the FPN settings. ``cfg['fpn']`` selects the
              implementation; for ``'hybrid_encoder'`` the keys ``width``,
              ``depth``, ``fpn_act``, ``fpn_norm``, ``fpn_depthwise``,
              ``en_num_heads``, ``en_num_layers``, ``en_mlp_ratio``,
              ``en_dropout``, ``pe_temperature`` and ``en_act`` are read.
          in_dims: Forwarded unchanged as ``HybridEncoder(in_dims=...)``.
          out_dim: Forwarded unchanged as ``HybridEncoder(out_dim=...)``.

      Returns:
          A ``HybridEncoder`` instance configured from ``cfg``.

      Raises:
          NotImplementedError: If ``cfg['fpn']`` is not ``'hybrid_encoder'``.
      """
      # Guard clause: reject unsupported FPN types up front.
      if cfg['fpn'] != 'hybrid_encoder':
          raise NotImplementedError("Unknown PaFPN: <{}>".format(cfg['fpn']))

      return HybridEncoder(
          in_dims=in_dims,
          out_dim=out_dim,
          width=cfg['width'],
          depth=cfg['depth'],
          act_type=cfg['fpn_act'],
          norm_type=cfg['fpn_norm'],
          depthwise=cfg['fpn_depthwise'],
          num_heads=cfg['en_num_heads'],
          num_layers=cfg['en_num_layers'],
          mlp_ratio=cfg['en_mlp_ratio'],
          dropout=cfg['en_dropout'],
          pe_temperature=cfg['pe_temperature'],
          en_act_type=cfg['en_act'],
      )
  28. # ----------------- Feature Pyramid Network -----------------
  29. ## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
  class HybridEncoder(nn.Module):
      """Hybrid encoder: Transformer encoder followed by a convolutional PaFPN.

      The module takes three feature maps ``(c3, c4, c5)``, projects each one
      to ``self.out_dim`` channels, runs the Transformer encoder on the
      projected ``c5`` only, and then fuses the three levels with a top-down
      pass followed by a bottom-up pass. Fusion is a channel concatenation
      (``dim=1``) fed to an ``RTCBlock``.

      Args:
          in_dims: Three values ``(c3, c4, c5)`` passed as the first argument
              of the 1x1 input projections (presumably the input channel
              counts of the three features -- confirm against ``BasicConv``).
          out_dim: Base output width; the width actually used is
              ``round(out_dim * width)`` and is stored as ``self.out_dim``.
          width: Multiplier applied to ``out_dim``.
          depth: Multiplier for the fusion blocks; each ``RTCBlock`` gets
              ``num_blocks = round(3 * depth)``.
          act_type: Forwarded to every ``BasicConv`` and ``RTCBlock``.
          norm_type: Forwarded to every ``BasicConv`` and ``RTCBlock``.
          depthwise: Forwarded to every ``RTCBlock``.
          num_heads: Forwarded to ``TransformerEncoder``.
          num_layers: Forwarded to ``TransformerEncoder``.
          mlp_ratio: Forwarded to ``TransformerEncoder``.
          dropout: Forwarded to ``TransformerEncoder``.
          pe_temperature: Forwarded to ``TransformerEncoder``.
          en_act_type: Forwarded to ``TransformerEncoder`` as ``act_type``.

      Note:
          The attribute names ``dowmsample_layer_1/2`` are misspelled on
          purpose-of-compatibility: submodule attribute names become
          ``state_dict`` keys, so renaming them would break existing
          checkpoints.
      """

      def __init__(self,
                   in_dims: List[int] = [256, 512, 512],
                   out_dim: int = 256,
                   width: float = 1.0,
                   depth: float = 1.0,
                   act_type: str = 'silu',
                   norm_type: str = 'BN',
                   depthwise: bool = False,
                   # Transformer's parameters
                   num_heads: int = 8,
                   num_layers: int = 1,
                   mlp_ratio: float = 4.0,
                   dropout: float = 0.1,
                   pe_temperature: float = 10000.,
                   en_act_type: str = 'gelu'
                   ) -> None:
          super().__init__()
          print('==============================')
          print('FPN: {}'.format("RTC-PaFPN"))
          # ---------------- Basic parameters ----------------
          # Store a copy: the default is a shared mutable list, and aliasing it
          # (or the caller's list) would let later mutations leak across instances.
          self.in_dims = list(in_dims)
          self.out_dim = round(out_dim * width)
          self.width = width
          self.depth = depth
          self.num_heads = num_heads
          self.num_layers = num_layers
          self.mlp_ratio = mlp_ratio
          c3, c4, c5 = self.in_dims

          # ---------------- Input projs ----------------
          # 1x1 convolutions mapping each input level to self.out_dim.
          self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
          self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
          self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

          # ---------------- Downsample ----------------
          # 3x3 stride-2 convolutions used by the bottom-up path.
          # (Attribute names keep the original spelling; see class docstring.)
          self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
          self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)

          # ---------------- Transformer Encoder ----------------
          self.transformer_encoder = TransformerEncoder(d_model=self.out_dim,
                                                        num_heads=num_heads,
                                                        num_layers=num_layers,
                                                        mlp_ratio=mlp_ratio,
                                                        pe_temperature=pe_temperature,
                                                        dropout=dropout,
                                                        act_type=en_act_type,
                                                        )

          # All four fusion blocks share the same configuration: they consume the
          # concatenation of two out_dim-channel maps and emit out_dim channels.
          fusion_kwargs = dict(in_dim=self.out_dim * 2,
                               out_dim=self.out_dim,
                               num_blocks=round(3 * depth),
                               shortcut=False,
                               act_type=act_type,
                               norm_type=norm_type,
                               depthwise=depthwise,
                               )

          # Construction order is kept as in the original so that submodule
          # registration order (and RNG consumption during init) is unchanged.
          # ---------------- Top down FPN ----------------
          ## P5 -> P4
          self.top_down_layer_1 = RTCBlock(**fusion_kwargs)
          ## P4 -> P3
          self.top_down_layer_2 = RTCBlock(**fusion_kwargs)

          # ---------------- Bottom up PAN ----------------
          ## P3 -> P4
          self.bottom_up_layer_1 = RTCBlock(**fusion_kwargs)
          ## P4 -> P5
          self.bottom_up_layer_2 = RTCBlock(**fusion_kwargs)

          self.init_weights()

      def init_weights(self):
          """Initialize the parameters."""
          for m in self.modules():
              if isinstance(m, torch.nn.Conv2d):
                  # In order to be consistent with the source code,
                  # reset the Conv2d initialization parameters
                  m.reset_parameters()

      def forward(self, features):
          """Fuse three feature levels and return ``[p3, p4, p5]``.

          Args:
              features: Sequence of exactly three tensors ``(c3, c4, c5)``.
                  They are concatenated along ``dim=1`` and resized over their
                  last two dimensions, so a ``(B, C, H, W)`` layout is assumed
                  -- TODO confirm against the backbone.

          Returns:
              List ``[p3, p4, p5]`` of fused feature maps.
          """
          c3, c4, c5 = features

          # -------- Input projs --------
          p5 = self.reduce_layer_1(c5)
          p4 = self.reduce_layer_2(c4)
          p3 = self.reduce_layer_3(c3)

          # -------- Transformer encoder --------
          # Only the projected c5 level goes through the Transformer encoder.
          p5 = self.transformer_encoder(p5)

          # -------- Top down FPN --------
          # Upsample to the spatial size of the level being fused instead of a
          # fixed scale_factor=2.0: for exact 2x pyramids the result is the
          # same, and odd-sized levels (e.g. 7x7 over 4x4) no longer make
          # torch.cat fail with a size mismatch.
          p5_up = F.interpolate(p5, size=p4.shape[-2:])
          p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
          p4_up = F.interpolate(p4, size=p3.shape[-2:])
          p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))

          # -------- Bottom up PAN --------
          # NOTE(review): this path still relies on the stride-2 downsample
          # producing exactly the spatial size of the next level.
          p3_ds = self.dowmsample_layer_1(p3)
          p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
          p4_ds = self.dowmsample_layer_2(p4)
          p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))

          out_feats = [p3, p4, p5]
          return out_feats
  if __name__ == '__main__':
      # Smoke test: build the encoder from a sample config, time one forward
      # pass on random inputs, print the output shapes, then profile the
      # model with thop.
      import time
      from thop import profile

      cfg = {
          'width': 1.0,
          'depth': 1.0,
          'fpn': 'hybrid_encoder',
          'fpn_act': 'silu',
          'fpn_norm': 'BN',
          'fpn_depthwise': False,
          'en_num_heads': 8,
          'en_num_layers': 1,
          'en_mlp_ratio': 4.0,
          'en_dropout': 0.1,
          'pe_temperature': 10000.,
          'en_act': 'gelu',
      }
      fpn_dims = [256, 512, 1024]
      out_dim = 256

      # One random single-sample feature map per level; the spatial size
      # halves from one level to the next (80 -> 40 -> 20).
      spatial_sizes = (80, 40, 20)
      pyramid_feats = [
          torch.randn(1, channels, side, side)
          for channels, side in zip(fpn_dims, spatial_sizes)
      ]

      model = build_fpn(cfg, fpn_dims, out_dim)

      # Wall-clock time of a single forward pass.
      tic = time.time()
      outputs = model(pyramid_feats)
      toc = time.time()
      print('Time: ', toc - tic)
      for feat in outputs:
          print(feat.shape)

      print('==============================')
      flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
      print('==============================')
      # The profiler count is doubled before being reported as GFLOPs.
      gflops = flops / 1e9 * 2
      mparams = params / 1e6
      print('GFLOPs : {:.2f}'.format(gflops))
      print('Params : {:.2f} M'.format(mparams))