# yolov3_fpn.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .yolov3_basic import Conv, ConvBlocks
  5. # YoloFPN
  6. class YoloFPN(nn.Module):
  7. def __init__(self,
  8. in_dims=[256, 512, 1024],
  9. width=1.0,
  10. depth=1.0,
  11. out_dim=None,
  12. act_type='silu',
  13. norm_type='BN'):
  14. super(YoloFPN, self).__init__()
  15. self.in_dims = in_dims
  16. self.out_dim = out_dim
  17. c3, c4, c5 = in_dims
  18. # P5 -> P4
  19. self.top_down_layer_1 = ConvBlocks(c5, int(512*width), act_type=act_type, norm_type=norm_type)
  20. self.reduce_layer_1 = Conv(int(512*width), int(256*width), k=1, act_type=act_type, norm_type=norm_type)
  21. # P4 -> P3
  22. self.top_down_layer_2 = ConvBlocks(c4 + int(256*width), int(256*width), act_type=act_type, norm_type=norm_type)
  23. self.reduce_layer_2 = Conv(int(256*width), int(128*width), k=1, act_type=act_type, norm_type=norm_type)
  24. # P3
  25. self.top_down_layer_3 = ConvBlocks(c3 + int(128*width), int(128*width), act_type=act_type, norm_type=norm_type)
  26. # output proj layers
  27. if out_dim is not None:
  28. # output proj layers
  29. self.out_layers = nn.ModuleList([
  30. Conv(in_dim, out_dim, k=1,
  31. norm_type=norm_type, act_type=act_type)
  32. for in_dim in [int(128 * width), int(256 * width), int(512 * width)]
  33. ])
  34. self.out_dim = [out_dim] * 3
  35. else:
  36. self.out_layers = None
  37. self.out_dim = [int(128 * width), int(256 * width), int(512 * width)]
  38. def forward(self, features):
  39. c3, c4, c5 = features
  40. # p5/32
  41. p5 = self.top_down_layer_1(c5)
  42. # p4/16
  43. p5_up = F.interpolate(self.reduce_layer_1(p5), scale_factor=2.0)
  44. p4 = self.top_down_layer_2(torch.cat([c4, p5_up], dim=1))
  45. # P3/8
  46. p4_up = F.interpolate(self.reduce_layer_2(p4), scale_factor=2.0)
  47. p3 = self.top_down_layer_3(torch.cat([c3, p4_up], dim=1))
  48. out_feats = [p3, p4, p5]
  49. # output proj layers
  50. if self.out_layers is not None:
  51. # output proj layers
  52. out_feats_proj = []
  53. for feat, layer in zip(out_feats, self.out_layers):
  54. out_feats_proj.append(layer(feat))
  55. return out_feats_proj
  56. return out_feats
  57. def build_fpn(cfg, in_dims, out_dim=None):
  58. model = cfg['fpn']
  59. # build neck
  60. if model == 'yolo_fpn':
  61. fpn_net = YoloFPN(in_dims=in_dims,
  62. out_dim=out_dim,
  63. width=cfg['width'],
  64. depth=cfg['depth'],
  65. act_type=cfg['fpn_act'],
  66. norm_type=cfg['fpn_norm']
  67. )
  68. return fpn_net