| 1234567891011121314151617181920212223242526272829303132333435363738 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- # Spatial Pyramid Pooling
class SPP(nn.Module):
    """
    Spatial Pyramid Pooling.

    Applies several stride-1 max-pooling windows to the same feature map
    and concatenates the input with every pooled result along the channel
    dimension. Each window uses padding ``k // 2`` so that an odd kernel
    size leaves the spatial dimensions unchanged.

    Args:
        kernel_sizes: iterable of odd positive ints, one per pooling
            branch. Defaults to ``(5, 9, 13)``.

    Raises:
        ValueError: if any kernel size is not an odd positive integer
            (an even size would change H/W and break the concatenation).
    """

    def __init__(self, kernel_sizes=(5, 9, 13)):
        super(SPP, self).__init__()
        kernel_sizes = tuple(kernel_sizes)
        for k in kernel_sizes:
            # Only odd kernels keep H and W unchanged with padding k // 2.
            if not isinstance(k, int) or k < 1 or k % 2 == 0:
                raise ValueError(
                    'SPP kernel sizes must be odd positive integers, '
                    'got {!r}'.format(k)
                )
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        """
        Input:
            x: (Tensor) -> [B, C, H, W]
        Output:
            y: (Tensor) -> [B, (1 + len(kernel_sizes)) * C, H, W]
               i.e. [B, 4C, H, W] with the default kernel sizes.
        """
        pooled = [
            F.max_pool2d(x, k, stride=1, padding=k // 2)
            for k in self.kernel_sizes
        ]
        # The un-pooled input comes first, followed by the branches in
        # the order given by kernel_sizes.
        return torch.cat([x] + pooled, dim=1)
def build_neck(cfg):
    """
    Build the neck module selected by the configuration.

    Prints a short banner naming the requested neck, then instantiates it.

    Args:
        cfg: object supporting ``cfg['neck']``; the value names the neck
            to build. Only ``'spp'`` is supported.

    Returns:
        An ``SPP`` instance when ``cfg['neck'] == 'spp'``.

    Raises:
        ValueError: if ``cfg['neck']`` names an unsupported neck.
    """
    model = cfg['neck']
    print('==============================')
    print('Neck: {}'.format(model))

    # build neck
    if model == 'spp':
        return SPP()

    # Fail with a clear message instead of reaching a `return neck`
    # where `neck` was never assigned.
    raise ValueError('Unsupported neck: {!r}'.format(model))
-
|