yolov7_fpn.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .yolov7_basic import Conv, ELANBlockFPN, DownSample, RepConv
  5. # PaFPN-ELAN (YOLOv7's)
  6. class Yolov7PaFPN(nn.Module):
  7. def __init__(self,
  8. in_dims=[512, 1024, 512],
  9. out_dim=None,
  10. width=1.0,
  11. depth=1.0,
  12. nbranch=4.0,
  13. act_type='silu',
  14. norm_type='BN',
  15. depthwise=False):
  16. super(Yolov7PaFPN, self).__init__()
  17. self.in_dims = in_dims
  18. c3, c4, c5 = in_dims
  19. # top dwon
  20. ## P5 -> P4
  21. self.reduce_layer_1 = Conv(c5, round(256*width), k=1, norm_type=norm_type, act_type=act_type)
  22. self.reduce_layer_2 = Conv(c4, round(256*width), k=1, norm_type=norm_type, act_type=act_type)
  23. self.top_down_layer_1 = ELANBlockFPN(in_dim=round(256*width) + round(256*width),
  24. out_dim=round(256*width),
  25. expand_ratio=0.5,
  26. nbranch=nbranch,
  27. depth=depth,
  28. act_type=act_type,
  29. norm_type=norm_type,
  30. depthwise=depthwise
  31. )
  32. # P4 -> P3
  33. self.reduce_layer_3 = Conv(round(256*width), round(128*width), k=1, norm_type=norm_type, act_type=act_type)
  34. self.reduce_layer_4 = Conv(c3, round(128*width), k=1, norm_type=norm_type, act_type=act_type)
  35. self.top_down_layer_2 = ELANBlockFPN(in_dim=round(128*width) + round(128*width),
  36. out_dim=round(128*width),
  37. expand_ratio=0.5,
  38. nbranch=nbranch,
  39. depth=depth,
  40. act_type=act_type,
  41. norm_type=norm_type,
  42. depthwise=depthwise
  43. )
  44. # bottom up
  45. # P3 -> P4
  46. self.downsample_layer_1 = DownSample(in_dim=round(128*width), out_dim=round(256*width),
  47. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  48. self.bottom_up_layer_1 = ELANBlockFPN(in_dim=round(256*width) + round(256*width),
  49. out_dim=round(256*width),
  50. expand_ratio=0.5,
  51. nbranch=nbranch,
  52. depth=depth,
  53. act_type=act_type,
  54. norm_type=norm_type,
  55. depthwise=depthwise
  56. )
  57. # P4 -> P5
  58. self.downsample_layer_2 = DownSample(in_dim=round(256*width), out_dim=round(512*width),
  59. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  60. self.bottom_up_layer_2 = ELANBlockFPN(in_dim=round(512*width) + c5,
  61. out_dim=round(512*width),
  62. expand_ratio=0.5,
  63. nbranch=nbranch,
  64. depth=depth,
  65. act_type=act_type,
  66. norm_type=norm_type,
  67. depthwise=depthwise
  68. )
  69. # head conv
  70. self.head_conv_1 = RepConv(round(128*width), round(256*width), k=3, s=1, p=1, act_type=act_type)
  71. self.head_conv_2 = RepConv(round(256*width), round(512*width), k=3, s=1, p=1, act_type=act_type)
  72. self.head_conv_3 = RepConv(round(512*width), round(1024*width), k=3, s=1, p=1, act_type=act_type)
  73. # output proj layers
  74. if out_dim is not None:
  75. self.out_layers = nn.ModuleList([
  76. Conv(in_dim, out_dim, k=1,
  77. norm_type=norm_type, act_type=act_type)
  78. for in_dim in [round(256*width), round(512*width), round(1024*width)]
  79. ])
  80. self.out_dim = [out_dim] * 3
  81. else:
  82. self.out_layers = None
  83. self.out_dim = [round(256*width), round(512*width), round(1024*width)]
  84. def forward(self, features):
  85. c3, c4, c5 = features
  86. # Top down
  87. ## P5 -> P4
  88. c6 = self.reduce_layer_1(c5)
  89. c7 = F.interpolate(c6, scale_factor=2.0)
  90. c8 = torch.cat([c7, self.reduce_layer_2(c4)], dim=1)
  91. c9 = self.top_down_layer_1(c8)
  92. ## P4 -> P3
  93. c10 = self.reduce_layer_3(c9)
  94. c11 = F.interpolate(c10, scale_factor=2.0)
  95. c12 = torch.cat([c11, self.reduce_layer_4(c3)], dim=1)
  96. c13 = self.top_down_layer_2(c12)
  97. # Bottom up
  98. ## p3 -> P4
  99. c14 = self.downsample_layer_1(c13)
  100. c15 = torch.cat([c14, c9], dim=1)
  101. c16 = self.bottom_up_layer_1(c15)
  102. ## P4 -> P5
  103. c17 = self.downsample_layer_2(c16)
  104. c18 = torch.cat([c17, c5], dim=1)
  105. c19 = self.bottom_up_layer_2(c18)
  106. c20 = self.head_conv_1(c13)
  107. c21 = self.head_conv_2(c16)
  108. c22 = self.head_conv_3(c19)
  109. out_feats = [c20, c21, c22] # [P3, P4, P5]
  110. # output proj layers
  111. if self.out_layers is not None:
  112. out_feats_proj = []
  113. for feat, layer in zip(out_feats, self.out_layers):
  114. out_feats_proj.append(layer(feat))
  115. return out_feats_proj
  116. return out_feats
  117. def build_fpn(cfg, in_dims, out_dim=None):
  118. model = cfg['fpn']
  119. # build pafpn
  120. if model == 'yolov7_pafpn':
  121. fpn_net = Yolov7PaFPN(in_dims=in_dims,
  122. out_dim=out_dim,
  123. width=cfg['width'],
  124. depth=cfg['depth'],
  125. nbranch=cfg['nbranch'],
  126. act_type=cfg['fpn_act'],
  127. norm_type=cfg['fpn_norm'],
  128. depthwise=cfg['fpn_depthwise']
  129. )
  130. return fpn_net