import torch.nn as nn
import torch.nn.functional as F

from utils import weight_init


# ------------------ Basic Feature Pyramid Network ------------------
class BasicFPN(nn.Module):
    def __init__(self,
                 cfg,
                 in_dims=[512, 1024, 2048],
                 out_dim=256,
                 ):
        super().__init__()
        # ------------------ Basic parameters -------------------
        self.p6_feat = cfg.fpn_p6_feat
        self.p7_feat = cfg.fpn_p7_feat
        self.from_c5 = cfg.fpn_p6_from_c5

        # ------------------ Network parameters -------------------
        ## lateral & smooth layers
        self.input_projs = nn.ModuleList()
        self.smooth_layers = nn.ModuleList()
        for in_dim in in_dims[::-1]:
            self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1))
            self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1))

        ## P6/P7 layers
        if self.p6_feat:
            if self.from_c5:
                self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1)
            else:  # from p5
                self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
        if self.p7_feat:
            self.p7_conv = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
            )

        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                weight_init.c2_xavier_fill(m)

    def forward(self, feats):
        """
            feats: (List of Tensor) [C3, C4, C5], C_i ∈ R^(B x C_i x H_i x W_i)
        """
        outputs = []
        # [C3, C4, C5] -> [C5, C4, C3]
        feats = feats[::-1]
        top_level_feat = feats[0]
        prev_feat = self.input_projs[0](top_level_feat)
        outputs.append(self.smooth_layers[0](prev_feat))

        # top-down pathway: upsample the coarser map and add the lateral projection
        for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]):
            feat = input_proj(feat)
            top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest')
            prev_feat = feat + top_down_feat
            outputs.insert(0, smooth_layer(prev_feat))

        if self.p6_feat:
            if self.from_c5:
                p6_feat = self.p6_conv(feats[0])
            else:
                p6_feat = self.p6_conv(outputs[-1])
            # [P3, P4, P5] -> [P3, P4, P5, P6]
            outputs.append(p6_feat)

            if self.p7_feat:
                p7_feat = self.p7_conv(p6_feat)
                # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7]
                outputs.append(p7_feat)

        # [P3, P4, P5] or [P3, P4, P5, P6, P7]
        return outputs
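

# --------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test,
# assuming the config object only needs the three fpn_* attributes read
# in __init__ and that `utils.weight_init` is importable. The feature
# shapes below are illustrative placeholders for a 640x640 input with
# strides 8/16/32.
if __name__ == "__main__":
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(fpn_p6_feat=True, fpn_p7_feat=True, fpn_p6_from_c5=False)
    fpn = BasicFPN(cfg, in_dims=[512, 1024, 2048], out_dim=256)

    # Dummy backbone features [C3, C4, C5]
    feats = [
        torch.randn(2, 512, 80, 80),
        torch.randn(2, 1024, 40, 40),
        torch.randn(2, 2048, 20, 20),
    ]
    pyramid = fpn(feats)
    for i, p in enumerate(pyramid):
        print(f"P{i + 3}: {tuple(p.shape)}")  # P3..P7, each with 256 channels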