import torch
import torch.nn as nn
import torch.nn.functional as F


# Spatial Pyramid Pooling
class SPP(nn.Module):
    """Spatial Pyramid Pooling.

    Concatenates the input with stride-1 max-pooled copies of itself along
    the channel dimension (dim=1). Each pooling branch uses padding
    ``k // 2``, which keeps H and W unchanged because every kernel size
    is required to be odd.
    """

    def __init__(self, kernel_sizes=(5, 9, 13)):
        """
        Args:
            kernel_sizes: iterable of positive odd ints, one max-pool branch
                per entry. The default ``(5, 9, 13)`` gives the classic
                4x-channel SPP output.

        Raises:
            ValueError: if any kernel size is not a positive odd int
                (an even kernel with padding ``k // 2`` would change H/W).
        """
        super().__init__()
        kernel_sizes = tuple(kernel_sizes)
        for k in kernel_sizes:
            if not isinstance(k, int) or k <= 0 or k % 2 == 0:
                raise ValueError(
                    'SPP kernel sizes must be positive odd ints, got {!r}'.format(k)
                )
        # Plain tuple attribute: not a parameter/buffer, so state_dict is unaffected.
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        """
        Input:
            x: (Tensor) -> [B, C, H, W]
        Output:
            y: (Tensor) -> [B, (1 + len(kernel_sizes)) * C, H, W],
               i.e. [B, 4C, H, W] with the default kernel sizes.
               Channel order is: x, then one pooled copy per kernel size
               in the order given.
        """
        # stride=1 with padding k // 2 ("same" padding for odd k) keeps H, W.
        pooled = [
            F.max_pool2d(x, k, stride=1, padding=k // 2) for k in self.kernel_sizes
        ]
        return torch.cat([x, *pooled], dim=1)


def build_neck(cfg):
    """Build the neck module selected by ``cfg['neck']``.

    Args:
        cfg: config mapping; only the ``'neck'`` key is read.

    Returns:
        nn.Module: the constructed neck (currently only ``'spp'`` -> ``SPP()``).

    Raises:
        KeyError: if ``cfg`` has no ``'neck'`` entry.
        ValueError: if ``cfg['neck']`` names an unsupported neck.
    """
    model = cfg['neck']
    print('==============================')
    print('Neck: {}'.format(model))

    # build neck
    if model == 'spp':
        return SPP()
    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError("Unsupported neck: {!r} (supported: 'spp')".format(model))