# pafpn.py — Real-time Convolutional PaFPN (RTC-PaFPN) feature pyramid network.
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .basic import BasicConv, RTCBlock
  5. # Build PaFPN
  6. def build_pafpn(cfg, in_dims, out_dim):
  7. return
  8. # ----------------- Feature Pyramid Network -----------------
  9. ## Real-time Convolutional PaFPN
  10. class HybridEncoder(nn.Module):
  11. def __init__(self,
  12. in_dims = [256, 512, 512],
  13. out_dim = 256,
  14. width = 1.0,
  15. depth = 1.0,
  16. act_type = 'silu',
  17. norm_type = 'BN',
  18. depthwise = False):
  19. super(HybridEncoder, self).__init__()
  20. print('==============================')
  21. print('FPN: {}'.format("RTC-PaFPN"))
  22. # ---------------- Basic parameters ----------------
  23. self.in_dims = in_dims
  24. self.out_dim = round(out_dim * width)
  25. self.width = width
  26. self.depth = depth
  27. c3, c4, c5 = in_dims
  28. # ---------------- Input projs ----------------
  29. self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  30. self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  31. self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  32. # ---------------- Downsample ----------------
  33. self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
  34. self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
  35. # ---------------- Top dwon FPN ----------------
  36. ## P5 -> P4
  37. self.top_down_layer_1 = RTCBlock(in_dim = self.out_dim * 2,
  38. out_dim = self.out_dim,
  39. num_blocks = round(3*depth),
  40. shortcut = False,
  41. act_type = act_type,
  42. norm_type = norm_type,
  43. depthwise = depthwise,
  44. )
  45. ## P4 -> P3
  46. self.top_down_layer_2 = RTCBlock(in_dim = self.out_dim * 2,
  47. out_dim = self.out_dim,
  48. num_blocks = round(3*depth),
  49. shortcut = False,
  50. act_type = act_type,
  51. norm_type = norm_type,
  52. depthwise = depthwise,
  53. )
  54. # ---------------- Bottom up PAN----------------
  55. ## P3 -> P4
  56. self.bottom_up_layer_1 = RTCBlock(in_dim = self.out_dim * 2,
  57. out_dim = self.out_dim,
  58. num_blocks = round(3*depth),
  59. shortcut = False,
  60. act_type = act_type,
  61. norm_type = norm_type,
  62. depthwise = depthwise,
  63. )
  64. ## P4 -> P5
  65. self.bottom_up_layer_2 = RTCBlock(in_dim = self.out_dim * 2,
  66. out_dim = self.out_dim,
  67. num_blocks = round(3*depth),
  68. shortcut = False,
  69. act_type = act_type,
  70. norm_type = norm_type,
  71. depthwise = depthwise,
  72. )
  73. self.init_weights()
  74. def init_weights(self):
  75. """Initialize the parameters."""
  76. for m in self.modules():
  77. if isinstance(m, torch.nn.Conv2d):
  78. # In order to be consistent with the source code,
  79. # reset the Conv2d initialization parameters
  80. m.reset_parameters()
  81. def forward(self, features):
  82. c3, c4, c5 = features
  83. # -------- Input projs --------
  84. p5 = self.reduce_layer_1(c5)
  85. p4 = self.reduce_layer_2(c4)
  86. p3 = self.reduce_layer_3(c3)
  87. # -------- Top down FPN --------
  88. p5_up = F.interpolate(p5, scale_factor=2.0)
  89. p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
  90. p4_up = F.interpolate(p4, scale_factor=2.0)
  91. p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
  92. # -------- Bottom up PAN --------
  93. p3_ds = self.dowmsample_layer_1(p3)
  94. p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
  95. p4_ds = self.dowmsample_layer_2(p4)
  96. p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
  97. out_feats = [p3, p4, p5]
  98. return out_feats