yolov7_pafpn.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .yolov7_basic import BasicConv, ELANLayerFPN, MDown
  6. # PaFPN-ELAN (YOLOv7's)
  7. class Yolov7PaFPN(nn.Module):
  8. def __init__(self, cfg, in_dims: List = [512, 1024, 512]):
  9. super(Yolov7PaFPN, self).__init__()
  10. # ----------------------------- Basic parameters -----------------------------
  11. self.in_dims = in_dims
  12. self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(1024*cfg.width)]
  13. c3, c4, c5 = in_dims
  14. # ----------------------------- Yolov7's Top-down FPN -----------------------------
  15. ## P5 -> P4
  16. self.reduce_layer_1 = BasicConv(c5, round(256*cfg.width),
  17. kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
  18. self.reduce_layer_2 = BasicConv(c4, round(256*cfg.width),
  19. kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
  20. self.top_down_layer_1 = ELANLayerFPN(in_dim = round(256*cfg.width) + round(256*cfg.width),
  21. out_dim = round(256*cfg.width),
  22. expansions = cfg.fpn_expansions,
  23. branch_width = cfg.fpn_block_bw,
  24. branch_depth = cfg.fpn_block_dw,
  25. act_type = cfg.fpn_act,
  26. norm_type = cfg.fpn_norm,
  27. depthwise = cfg.fpn_depthwise,
  28. )
  29. ## P4 -> P3
  30. self.reduce_layer_3 = BasicConv(round(256*cfg.width), round(128*cfg.width),
  31. kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
  32. self.reduce_layer_4 = BasicConv(c3, round(128*cfg.width),
  33. kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
  34. self.top_down_layer_2 = ELANLayerFPN(in_dim = round(128*cfg.width) + round(128*cfg.width),
  35. out_dim = round(128*cfg.width),
  36. expansions = cfg.fpn_expansions,
  37. branch_width = cfg.fpn_block_bw,
  38. branch_depth = cfg.fpn_block_dw,
  39. act_type = cfg.fpn_act,
  40. norm_type = cfg.fpn_norm,
  41. depthwise = cfg.fpn_depthwise,
  42. )
  43. # ----------------------------- Yolov7's Bottom-up PAN -----------------------------
  44. ## P3 -> P4
  45. self.downsample_layer_1 = MDown(round(128*cfg.width), round(256*cfg.width),
  46. act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
  47. self.bottom_up_layer_1 = ELANLayerFPN(in_dim = round(256*cfg.width) + round(256*cfg.width),
  48. out_dim = round(256*cfg.width),
  49. expansions = cfg.fpn_expansions,
  50. branch_width = cfg.fpn_block_bw,
  51. branch_depth = cfg.fpn_block_dw,
  52. act_type = cfg.fpn_act,
  53. norm_type = cfg.fpn_norm,
  54. depthwise = cfg.fpn_depthwise,
  55. )
  56. ## P4 -> P5
  57. self.downsample_layer_2 = MDown(round(256*cfg.width), round(512*cfg.width),
  58. act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
  59. self.bottom_up_layer_2 = ELANLayerFPN(in_dim = round(512*cfg.width) + c5,
  60. out_dim = round(512*cfg.width),
  61. expansions = cfg.fpn_expansions,
  62. branch_width = cfg.fpn_block_bw,
  63. branch_depth = cfg.fpn_block_dw,
  64. act_type = cfg.fpn_act,
  65. norm_type = cfg.fpn_norm,
  66. depthwise = cfg.fpn_depthwise,
  67. )
  68. # ----------------------------- Head conv layers -----------------------------
  69. ## Head convs
  70. self.head_conv_1 = BasicConv(round(128*cfg.width), round(256*cfg.width),
  71. kernel_size=3, padding=1, stride=1,
  72. act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
  73. self.head_conv_2 = BasicConv(round(256*cfg.width), round(512*cfg.width),
  74. kernel_size=3, padding=1, stride=1,
  75. act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
  76. self.head_conv_3 = BasicConv(round(512*cfg.width), round(1024*cfg.width),
  77. kernel_size=3, padding=1, stride=1,
  78. act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
  79. def forward(self, features):
  80. c3, c4, c5 = features
  81. # ------------------ Top down FPN ------------------
  82. ## P5 -> P4
  83. p5 = self.reduce_layer_1(c5)
  84. p5_up = F.interpolate(p5, scale_factor=2.0)
  85. p4 = self.reduce_layer_2(c4)
  86. p4 = self.top_down_layer_1(torch.cat([p5_up, p4], dim=1))
  87. ## P4 -> P3
  88. p4_in = self.reduce_layer_3(p4)
  89. p4_up = F.interpolate(p4_in, scale_factor=2.0)
  90. p3 = self.reduce_layer_4(c3)
  91. p3 = self.top_down_layer_2(torch.cat([p4_up, p3], dim=1))
  92. # ------------------ Bottom up PAN ------------------
  93. ## P3 -> P4
  94. p3_ds = self.downsample_layer_1(p3)
  95. p4 = torch.cat([p3_ds, p4], dim=1)
  96. p4 = self.bottom_up_layer_1(p4)
  97. ## P4 -> P5
  98. p4_ds = self.downsample_layer_2(p4)
  99. p5 = torch.cat([p4_ds, c5], dim=1)
  100. p5 = self.bottom_up_layer_2(p5)
  101. out_feats = [self.head_conv_1(p3), self.head_conv_2(p4), self.head_conv_3(p5)]
  102. return out_feats