# fpn.py — Basic Feature Pyramid Network
import torch.nn as nn
import torch.nn.functional as F
from utils import weight_init
  4. # ------------------ Basic Feature Pyramid Network ------------------
  5. class BasicFPN(nn.Module):
  6. def __init__(self,
  7. in_dims=[512, 1024, 2048],
  8. out_dim=256,
  9. p6_feat=False,
  10. p7_feat=False,
  11. from_c5=False,
  12. ):
  13. super().__init__()
  14. # ------------------ Basic parameters -------------------
  15. self.p6_feat = p6_feat
  16. self.p7_feat = p7_feat
  17. self.from_c5 = from_c5
  18. # ------------------ Network parameters -------------------
  19. ## latter layers
  20. self.input_projs = nn.ModuleList()
  21. self.smooth_layers = nn.ModuleList()
  22. for in_dim in in_dims[::-1]:
  23. self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1))
  24. self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1))
  25. ## P6/P7 layers
  26. if p6_feat:
  27. if from_c5:
  28. self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1)
  29. else: # from p5
  30. self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
  31. if p7_feat:
  32. self.p7_conv = nn.Sequential(
  33. nn.ReLU(inplace=True),
  34. nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
  35. )
  36. self._init_weight()
  37. def _init_weight(self):
  38. for m in self.modules():
  39. if isinstance(m, nn.Conv2d):
  40. weight_init.c2_xavier_fill(m)
  41. def forward(self, feats):
  42. """
  43. feats: (List of Tensor) [C3, C4, C5], C_i ∈ R^(B x C_i x H_i x W_i)
  44. """
  45. outputs = []
  46. # [C3, C4, C5] -> [C5, C4, C3]
  47. feats = feats[::-1]
  48. top_level_feat = feats[0]
  49. prev_feat = self.input_projs[0](top_level_feat)
  50. outputs.append(self.smooth_layers[0](prev_feat))
  51. for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]):
  52. feat = input_proj(feat)
  53. top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest')
  54. prev_feat = feat + top_down_feat
  55. outputs.insert(0, smooth_layer(prev_feat))
  56. if self.p6_feat:
  57. if self.from_c5:
  58. p6_feat = self.p6_conv(feats[0])
  59. else:
  60. p6_feat = self.p6_conv(outputs[-1])
  61. # [P3, P4, P5] -> [P3, P4, P5, P6]
  62. outputs.append(p6_feat)
  63. if self.p7_feat:
  64. p7_feat = self.p7_conv(p6_feat)
  65. # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7]
  66. outputs.append(p7_feat)
  67. # [P3, P4, P5] or [P3, P4, P5, P6, P7]
  68. return outputs