# fpn.py — Basic Feature Pyramid Network
import torch.nn as nn
import torch.nn.functional as F

from utils import weight_init

# ------------------ Basic Feature Pyramid Network ------------------
  5. class BasicFPN(nn.Module):
  6. def __init__(self, cfg,
  7. in_dims=[512, 1024, 2048],
  8. out_dim=256,
  9. ):
  10. super().__init__()
  11. # ------------------ Basic parameters -------------------
  12. self.p6_feat = cfg.fpn_p6_feat
  13. self.p7_feat = cfg.fpn_p7_feat
  14. self.from_c5 = cfg.fpn_p6_from_c5
  15. # ------------------ Network parameters -------------------
  16. ## latter layers
  17. self.input_projs = nn.ModuleList()
  18. self.smooth_layers = nn.ModuleList()
  19. for in_dim in in_dims[::-1]:
  20. self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1))
  21. self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1))
  22. ## P6/P7 layers
  23. if self.p6_feat:
  24. if self.from_c5:
  25. self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1)
  26. else: # from p5
  27. self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
  28. if self.p7_feat:
  29. self.p7_conv = nn.Sequential(
  30. nn.ReLU(inplace=True),
  31. nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
  32. )
  33. self._init_weight()
  34. def _init_weight(self):
  35. for m in self.modules():
  36. if isinstance(m, nn.Conv2d):
  37. weight_init.c2_xavier_fill(m)
  38. def forward(self, feats):
  39. """
  40. feats: (List of Tensor) [C3, C4, C5], C_i ∈ R^(B x C_i x H_i x W_i)
  41. """
  42. outputs = []
  43. # [C3, C4, C5] -> [C5, C4, C3]
  44. feats = feats[::-1]
  45. top_level_feat = feats[0]
  46. prev_feat = self.input_projs[0](top_level_feat)
  47. outputs.append(self.smooth_layers[0](prev_feat))
  48. for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]):
  49. feat = input_proj(feat)
  50. top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest')
  51. prev_feat = feat + top_down_feat
  52. outputs.insert(0, smooth_layer(prev_feat))
  53. if self.p6_feat:
  54. if self.from_c5:
  55. p6_feat = self.p6_conv(feats[0])
  56. else:
  57. p6_feat = self.p6_conv(outputs[-1])
  58. # [P3, P4, P5] -> [P3, P4, P5, P6]
  59. outputs.append(p6_feat)
  60. if self.p7_feat:
  61. p7_feat = self.p7_conv(p6_feat)
  62. # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7]
  63. outputs.append(p7_feat)
  64. # [P3, P4, P5] or [P3, P4, P5, P6, P7]
  65. return outputs