yolov3_fpn.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from typing import List
  5. try:
  6. from .modules import ConvModule, ConvBlocks
  7. except:
  8. from modules import ConvModule, ConvBlocks
  9. # Yolov3FPN
  10. class Yolov3FPN(nn.Module):
  11. def __init__(self,
  12. in_dims: List = [256, 512, 1024],
  13. head_dim: int = 256,
  14. ):
  15. super(Yolov3FPN, self).__init__()
  16. self.in_dims = in_dims
  17. self.head_dim = head_dim
  18. self.fpn_out_dims = [head_dim] * 3
  19. c3, c4, c5 = in_dims
  20. # P5 -> P4
  21. self.top_down_layer_1 = ConvBlocks(c5, 512)
  22. self.reduce_layer_1 = ConvModule(512, 256, kernel_size=1)
  23. # P4 -> P3
  24. self.top_down_layer_2 = ConvBlocks(c4 + 256, 256)
  25. self.reduce_layer_2 = ConvModule(256, 128, kernel_size=1)
  26. # P3
  27. self.top_down_layer_3 = ConvBlocks(c3 + 128, 128)
  28. # output proj layers
  29. self.out_layers = nn.ModuleList([ConvModule(in_dim, head_dim, kernel_size=1)
  30. for in_dim in [128, 256, 512]
  31. ])
  32. def forward(self, features):
  33. c3, c4, c5 = features
  34. # p5/32
  35. p5 = self.top_down_layer_1(c5)
  36. # p4/16
  37. p5_up = F.interpolate(self.reduce_layer_1(p5), scale_factor=2.0)
  38. p4 = self.top_down_layer_2(torch.cat([c4, p5_up], dim=1))
  39. # P3/8
  40. p4_up = F.interpolate(self.reduce_layer_2(p4), scale_factor=2.0)
  41. p3 = self.top_down_layer_3(torch.cat([c3, p4_up], dim=1))
  42. out_feats = [p3, p4, p5]
  43. # output proj layers
  44. out_feats_proj = []
  45. for feat, layer in zip(out_feats, self.out_layers):
  46. out_feats_proj.append(layer(feat))
  47. return out_feats_proj
  48. if __name__=='__main__':
  49. import time
  50. from thop import profile
  51. # Model config
  52. # Build a head
  53. in_dims = [128, 256, 512]
  54. fpn = Yolov3FPN(in_dims, head_dim=256)
  55. # Randomly generate a input data
  56. x = [torch.randn(1, in_dims[0], 80, 80),
  57. torch.randn(1, in_dims[1], 40, 40),
  58. torch.randn(1, in_dims[2], 20, 20)]
  59. # Inference
  60. t0 = time.time()
  61. output = fpn(x)
  62. t1 = time.time()
  63. print('Time: ', t1 - t0)
  64. print('====== FPN output ====== ')
  65. for level, feat in enumerate(output):
  66. print("- Level-{} : ".format(level), feat.shape)
  67. flops, params = profile(fpn, inputs=(x, ), verbose=False)
  68. print('==============================')
  69. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  70. print('Params : {:.2f} M'.format(params / 1e6))