# fcos_fpn.py (2.9 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from typing import List
  5. # ------------------ Basic Feature Pyramid Network ------------------
  6. class FcosFPN(nn.Module):
  7. def __init__(self, cfg, in_dims: List = [512, 1024, 2048]):
  8. super().__init__()
  9. # ------------------ Basic parameters -------------------
  10. self.out_dim = cfg.head_dim
  11. # ------------------ Network parameters -------------------
  12. self.input_proj_1 = nn.Conv2d(in_dims[0], self.out_dim, kernel_size=1)
  13. self.input_proj_2 = nn.Conv2d(in_dims[1], self.out_dim, kernel_size=1)
  14. self.input_proj_3 = nn.Conv2d(in_dims[2], self.out_dim, kernel_size=1)
  15. self.smooth_layer_1 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
  16. self.smooth_layer_2 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
  17. self.smooth_layer_3 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
  18. self.p6_conv = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, stride=2, padding=1)
  19. def forward(self, feats):
  20. """
  21. feats: (List of Tensor) [C3, C4, C5]
  22. """
  23. c3, c4, c5 = feats
  24. # -------- Input projection --------
  25. p3 = self.input_proj_1(c3)
  26. p4 = self.input_proj_2(c4)
  27. p5 = self.input_proj_3(c5)
  28. # -------- Feature fusion --------
  29. outputs = [self.smooth_layer_3(p5)]
  30. # P5 -> P4
  31. p4 = p4 + F.interpolate(p5, size=p4.shape[2:], mode='nearest')
  32. outputs.insert(0, self.smooth_layer_2(p4))
  33. # P4 -> P3
  34. p3 = p3 + F.interpolate(p4, size=p3.shape[2:], mode='nearest')
  35. outputs.insert(0, self.smooth_layer_1(p3))
  36. # P5 -> P6
  37. outputs.append(self.p6_conv(outputs[-1]))
  38. # [P3, P4, P5, P6]
  39. return outputs
  40. if __name__=='__main__':
  41. import time
  42. from thop import profile
  43. # Model config
  44. # YOLOv2-Base config
  45. class FcosBaseConfig(object):
  46. def __init__(self) -> None:
  47. # ---------------- Model config ----------------
  48. self.width = 0.50
  49. self.depth = 0.34
  50. self.out_stride = [8, 16, 32, 64]
  51. ## Head
  52. self.head_dim = 256
  53. cfg = FcosBaseConfig()
  54. # Build a head
  55. in_dims = [128, 256, 512]
  56. fpn = FcosFPN(cfg, in_dims)
  57. # Inference
  58. x = [torch.randn(1, in_dims[0], 80, 80),
  59. torch.randn(1, in_dims[1], 40, 40),
  60. torch.randn(1, in_dims[2], 20, 20)]
  61. t0 = time.time()
  62. output = fpn(x)
  63. t1 = time.time()
  64. print('Time: ', t1 - t0)
  65. print('====== FPN output ====== ')
  66. for level, feat in enumerate(output):
  67. print("- Level-{} : ".format(level), feat.shape)
  68. flops, params = profile(fpn, inputs=(x, ), verbose=False)
  69. print('==============================')
  70. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  71. print('Params : {:.2f} M'.format(params / 1e6))