| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- import torch.nn as nn
- import torch.nn.functional as F
- from utils import weight_init
- # ------------------ Basic Feature Pyramid Network ------------------
class BasicFPN(nn.Module):
    """Basic Feature Pyramid Network with optional P6/P7 levels.

    Every backbone feature is projected to ``out_dim`` channels by a 1x1
    convolution, merged top-down (nearest-neighbour upsampling followed by
    element-wise addition) and smoothed by a 3x3 convolution.  Extra P6 and
    P7 levels can be appended with stride-2 3x3 convolutions.

    Args:
        cfg: Configuration object.  The attributes ``fpn_p6_feat``,
            ``fpn_p7_feat`` and ``fpn_p6_from_c5`` are read.
        in_dims: Channel counts of the backbone features, ordered from the
            shallowest to the deepest level (``[C3, C4, C5]`` by default).
        out_dim: Channel count of every pyramid level.
    """

    def __init__(self, cfg,
                 in_dims=(512, 1024, 2048),
                 out_dim=256,
                 ):
        super().__init__()
        # ------------------ Basic parameters -------------------
        self.p6_feat = cfg.fpn_p6_feat
        self.p7_feat = cfg.fpn_p7_feat
        self.from_c5 = cfg.fpn_p6_from_c5

        # ------------------ Network parameters -------------------
        ## lateral layers, stored deepest level first ([C5, C4, C3])
        self.input_projs = nn.ModuleList()
        self.smooth_layers = nn.ModuleList()
        for in_dim in in_dims[::-1]:
            self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1))
            self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1))

        ## P6/P7 layers
        if self.p6_feat:
            if self.from_c5:
                # P6 is computed from the raw C5 backbone feature.
                self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1)
            else:  # from p5
                self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
        if self.p7_feat:
            # NOTE: the in-place ReLU rewrites its input.  In ``forward`` that
            # input is the P6 tensor already stored in the output list, so the
            # returned P6 is rectified as well whenever P7 is enabled.
            self.p7_conv = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
            )

        self._init_weight()

    def _init_weight(self):
        """Apply ``weight_init.c2_xavier_fill`` to every ``nn.Conv2d`` submodule."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                weight_init.c2_xavier_fill(m)

    def forward(self, feats):
        """Build the feature pyramid.

        Args:
            feats: Sequence of backbone feature maps ``[C3, C4, C5]``, ordered
                from the shallowest to the deepest level, each of shape
                ``(B, C_i, H_i, W_i)``.

        Returns:
            List of pyramid features ``[P3, P4, P5]``, followed by ``P6`` when
            ``p6_feat`` is set and additionally by ``P7`` when ``p7_feat`` is
            set too.  ``P7`` is derived from ``P6`` and is therefore only
            produced when ``P6`` is enabled.
        """
        # [C3, C4, C5] -> [C5, C4, C3]
        feats = feats[::-1]

        # The deepest level has no top-down input: project and smooth it.
        prev_feat = self.input_projs[0](feats[0])
        outputs = [self.smooth_layers[0](prev_feat)]

        # Top-down pathway; results are collected deepest first.
        for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]):
            feat = input_proj(feat)
            top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest')
            prev_feat = feat + top_down_feat
            outputs.append(smooth_layer(prev_feat))

        # [P5, P4, P3] -> [P3, P4, P5]
        outputs.reverse()

        if self.p6_feat:
            if self.from_c5:
                p6_feat = self.p6_conv(feats[0])
            else:
                p6_feat = self.p6_conv(outputs[-1])
            # [P3, P4, P5] -> [P3, P4, P5, P6]
            outputs.append(p6_feat)

            if self.p7_feat:
                p7_feat = self.p7_conv(p6_feat)
                # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7]
                outputs.append(p7_feat)

        # [P3, P4, P5] or [P3, P4, P5, P6, P7]
        return outputs
|