rtcdetv2_pafpn.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. try:
  5. from .rtcdetv2_basic import Conv, ResXStage
  6. except:
  7. from rtcdetv2_basic import Conv, ResXStage
  8. # PaFPN-CSP
  9. class RTCDetv2PaFPN(nn.Module):
  10. def __init__(self,
  11. in_dims=[256, 512, 1024],
  12. out_dim=256,
  13. width=1.0,
  14. depth=1.0,
  15. act_type='silu',
  16. norm_type='BN',
  17. depthwise=False):
  18. super(RTCDetv2PaFPN, self).__init__()
  19. # ------------- Basic parameters -------------
  20. self.in_dims = in_dims
  21. self.out_dim = out_dim
  22. self.expand_ratios = [0.25, 0.25, 0.25, 0.25]
  23. self.ffn_ratios = [4.0, 4.0, 4.0, 4.0]
  24. self.num_branches = [4, 4, 4, 4]
  25. self.num_blocks = [round(2 * depth), round(2 * depth), round(2 * depth), round(2 * depth)]
  26. c3, c4, c5 = in_dims
  27. # top down
  28. ## P5 -> P4
  29. self.reduce_layer_1 = Conv(c5, round(384*width), k=1, act_type=act_type, norm_type=norm_type)
  30. self.top_down_layer_1 = ResXStage(in_dim = c4 + round(384*width),
  31. out_dim = int(384*width),
  32. expand_ratio = self.expand_ratios[0],
  33. ffn_ratio = self.ffn_ratios[0],
  34. num_branches = self.num_branches[0],
  35. num_blocks = self.num_blocks[0],
  36. shortcut = False,
  37. act_type = act_type,
  38. norm_type = norm_type,
  39. depthwise = depthwise
  40. )
  41. ## P4 -> P3
  42. self.reduce_layer_2 = Conv(c4, round(192*width), k=1, norm_type=norm_type, act_type=act_type)
  43. self.top_down_layer_2 = ResXStage(in_dim = c3 + round(192*width),
  44. out_dim = round(192*width),
  45. expand_ratio = self.expand_ratios[1],
  46. ffn_ratio = self.ffn_ratios[1],
  47. num_branches = self.num_branches[1],
  48. num_blocks = self.num_blocks[1],
  49. shortcut = False,
  50. act_type = act_type,
  51. norm_type = norm_type,
  52. depthwise = depthwise
  53. )
  54. # bottom up
  55. ## P3 -> P4
  56. self.downsample_layer_1 = Conv(round(192*width), round(192*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  57. self.bottom_up_layer_1 = ResXStage(in_dim = round(192*width) + round(192*width),
  58. out_dim = round(384*width),
  59. expand_ratio = self.expand_ratios[2],
  60. ffn_ratio = self.ffn_ratios[2],
  61. num_branches = self.num_branches[2],
  62. num_blocks = self.num_blocks[2],
  63. shortcut = False,
  64. act_type = act_type,
  65. norm_type = norm_type,
  66. depthwise = depthwise
  67. )
  68. ## P4 -> P5
  69. self.downsample_layer_2 = Conv(round(384*width), round(384*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  70. self.bottom_up_layer_2 = ResXStage(in_dim = round(384*width) + round(384*width),
  71. out_dim = round(768*width),
  72. expand_ratio = self.expand_ratios[3],
  73. ffn_ratio = self.ffn_ratios[3],
  74. num_branches = self.num_branches[3],
  75. num_blocks = self.num_blocks[3],
  76. shortcut = False,
  77. act_type = act_type,
  78. norm_type = norm_type,
  79. depthwise = depthwise
  80. )
  81. # output proj layers
  82. if out_dim is not None:
  83. # output proj layers
  84. self.out_layers = nn.ModuleList([
  85. Conv(in_dim, out_dim, k=1,
  86. norm_type=norm_type, act_type=act_type)
  87. for in_dim in [round(192 * width), round(384 * width), round(768 * width)]
  88. ])
  89. self.out_dim = [out_dim] * 3
  90. else:
  91. self.out_layers = None
  92. self.out_dim = [round(192 * width), round(384 * width), round(768 * width)]
  93. def forward(self, features):
  94. c3, c4, c5 = features
  95. c6 = self.reduce_layer_1(c5)
  96. c7 = F.interpolate(c6, scale_factor=2.0) # s32->s16
  97. c8 = torch.cat([c7, c4], dim=1)
  98. c9 = self.top_down_layer_1(c8)
  99. # P3/8
  100. c10 = self.reduce_layer_2(c9)
  101. c11 = F.interpolate(c10, scale_factor=2.0) # s16->s8
  102. c12 = torch.cat([c11, c3], dim=1)
  103. c13 = self.top_down_layer_2(c12) # to det
  104. # p4/16
  105. c14 = self.downsample_layer_1(c13)
  106. c15 = torch.cat([c14, c10], dim=1)
  107. c16 = self.bottom_up_layer_1(c15) # to det
  108. # p5/32
  109. c17 = self.downsample_layer_2(c16)
  110. c18 = torch.cat([c17, c6], dim=1)
  111. c19 = self.bottom_up_layer_2(c18) # to det
  112. out_feats = [c13, c16, c19] # [P3, P4, P5]
  113. # output proj layers
  114. if self.out_layers is not None:
  115. # output proj layers
  116. out_feats_proj = []
  117. for feat, layer in zip(out_feats, self.out_layers):
  118. out_feats_proj.append(layer(feat))
  119. return out_feats_proj
  120. return out_feats
  121. def build_fpn(cfg, in_dims, out_dim=None):
  122. model = cfg['fpn']
  123. # build neck
  124. if model == 'rtcdetv2_pafpn':
  125. fpn_net = RTCDetv2PaFPN(in_dims = in_dims,
  126. out_dim = out_dim,
  127. width = cfg['width'],
  128. depth = cfg['depth'],
  129. act_type = cfg['fpn_act'],
  130. norm_type = cfg['fpn_norm'],
  131. depthwise = cfg['fpn_depthwise']
  132. )
  133. return fpn_net
  134. if __name__ == '__main__':
  135. import time
  136. from thop import profile
  137. cfg = {
  138. 'width': 1.0,
  139. 'depth': 1.0,
  140. 'fpn': 'rtcdetv2_pafpn',
  141. 'fpn_act': 'silu',
  142. 'fpn_norm': 'BN',
  143. 'fpn_depthwise': False,
  144. }
  145. fpn_dims = [192, 384, 768]
  146. out_dim = 192
  147. # Head-1
  148. model = build_fpn(cfg, fpn_dims, out_dim)
  149. fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
  150. t0 = time.time()
  151. outputs = model(fpn_feats)
  152. t1 = time.time()
  153. print('Time: ', t1 - t0)
  154. # for out in outputs:
  155. # print(out.shape)
  156. print('==============================')
  157. flops, params = profile(model, inputs=(fpn_feats, ), verbose=False)
  158. print('==============================')
  159. print('FPN: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  160. print('FPN: Params : {:.2f} M'.format(params / 1e6))