fpn.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
from typing import List, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from .conv import BasicConv, ELANLayer
from .transformer import TransformerEncoder
# Build the PaFPN neck selected by ``cfg.fpn``.
  8. def build_fpn(cfg, in_dims):
  9. if cfg.fpn == 'hybrid_encoder':
  10. return HybridEncoder(in_dims = in_dims,
  11. out_dim = cfg.hidden_dim,
  12. num_blocks = cfg.fpn_num_blocks,
  13. expand_ratio= cfg.fpn_expand_ratio,
  14. act_type = cfg.fpn_act,
  15. norm_type = cfg.fpn_norm,
  16. depthwise = cfg.fpn_depthwise,
  17. num_heads = cfg.en_num_heads,
  18. num_layers = cfg.en_num_layers,
  19. ffn_dim = cfg.en_ffn_dim,
  20. dropout = cfg.en_dropout,
  21. en_act_type = cfg.en_act,
  22. )
  23. else:
  24. raise NotImplementedError("Unknown PaFPN: <{}>".format(cfg.fpn))
# ----------------- Feature Pyramid Network -----------------
## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
  26. class HybridEncoder(nn.Module):
  27. def __init__(self,
  28. in_dims :List = [256, 512, 1024],
  29. out_dim :int = 256,
  30. num_blocks :int = 3,
  31. expand_ratio :float = 0.5,
  32. act_type :str = 'silu',
  33. norm_type :str = 'BN',
  34. depthwise :bool = False,
  35. # Transformer's parameters
  36. num_heads :int = 8,
  37. num_layers :int = 1,
  38. ffn_dim :int = 1024,
  39. dropout :float = 0.1,
  40. pe_temperature :float = 10000.,
  41. en_act_type :str = 'gelu'
  42. ) -> None:
  43. super(HybridEncoder, self).__init__()
  44. print('==============================')
  45. print('FPN: {}'.format("RTC-PaFPN"))
  46. # ---------------- Basic parameters ----------------
  47. self.in_dims = in_dims
  48. self.out_dim = out_dim
  49. self.out_dims = [self.out_dim] * len(in_dims)
  50. self.num_heads = num_heads
  51. self.num_layers = num_layers
  52. self.ffn_dim = ffn_dim
  53. c3, c4, c5 = in_dims
  54. # ---------------- Input projs ----------------
  55. self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  56. self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  57. self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  58. # ---------------- Downsample ----------------
  59. self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim,
  60. kernel_size=3, padding=1, stride=2,
  61. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  62. self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim,
  63. kernel_size=3, padding=1, stride=2,
  64. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  65. # ---------------- Transformer Encoder ----------------
  66. self.transformer_encoder = TransformerEncoder(d_model = self.out_dim,
  67. num_heads = num_heads,
  68. num_layers = num_layers,
  69. ffn_dim = ffn_dim,
  70. pe_temperature = pe_temperature,
  71. dropout = dropout,
  72. act_type = en_act_type
  73. )
  74. # ---------------- Top dwon FPN ----------------
  75. ## P5 -> P4
  76. self.top_down_layer_1 = ELANLayer(in_dim = self.out_dim * 2,
  77. out_dim = self.out_dim,
  78. num_blocks = num_blocks,
  79. expand_ratio = expand_ratio,
  80. shortcut = False,
  81. act_type = act_type,
  82. norm_type = norm_type,
  83. depthwise = depthwise,
  84. )
  85. ## P4 -> P3
  86. self.top_down_layer_2 = ELANLayer(in_dim = self.out_dim * 2,
  87. out_dim = self.out_dim,
  88. num_blocks = num_blocks,
  89. expand_ratio = expand_ratio,
  90. shortcut = False,
  91. act_type = act_type,
  92. norm_type = norm_type,
  93. depthwise = depthwise,
  94. )
  95. # ---------------- Bottom up PAN----------------
  96. ## P3 -> P4
  97. self.bottom_up_layer_1 = ELANLayer(in_dim = self.out_dim * 2,
  98. out_dim = self.out_dim,
  99. num_blocks = num_blocks,
  100. expand_ratio = expand_ratio,
  101. shortcut = False,
  102. act_type = act_type,
  103. norm_type = norm_type,
  104. depthwise = depthwise,
  105. )
  106. ## P4 -> P5
  107. self.bottom_up_layer_2 = ELANLayer(in_dim = self.out_dim * 2,
  108. out_dim = self.out_dim,
  109. num_blocks = num_blocks,
  110. expand_ratio = expand_ratio,
  111. shortcut = False,
  112. act_type = act_type,
  113. norm_type = norm_type,
  114. depthwise = depthwise,
  115. )
  116. self.init_weights()
  117. def init_weights(self):
  118. """Initialize the parameters."""
  119. for m in self.modules():
  120. if isinstance(m, torch.nn.Conv2d):
  121. # In order to be consistent with the source code,
  122. # reset the Conv2d initialization parameters
  123. m.reset_parameters()
  124. def forward(self, features):
  125. c3, c4, c5 = features
  126. # -------- Input projs --------
  127. p5 = self.reduce_layer_1(c5)
  128. p4 = self.reduce_layer_2(c4)
  129. p3 = self.reduce_layer_3(c3)
  130. # -------- Transformer encoder --------
  131. p5 = self.transformer_encoder(p5)
  132. # -------- Top down FPN --------
  133. p5_up = F.interpolate(p5, scale_factor=2.0)
  134. p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
  135. p4_up = F.interpolate(p4, scale_factor=2.0)
  136. p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
  137. # -------- Bottom up PAN --------
  138. p3_ds = self.dowmsample_layer_1(p3)
  139. p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
  140. p4_ds = self.dowmsample_layer_2(p4)
  141. p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
  142. out_feats = [p3, p4, p5]
  143. return out_feats