yolov1_neck.py 827 B

1234567891011121314151617181920212223242526272829303132333435363738
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. # Spatial Pyramid Pooling
  5. class SPP(nn.Module):
  6. """
  7. Spatial Pyramid Pooling
  8. """
  9. def __init__(self):
  10. super(SPP, self).__init__()
  11. def forward(self, x):
  12. """
  13. Input:
  14. x: (Tensor) -> [B, C, H, W]
  15. Output:
  16. y: (Tensor) -> [B, 4C, H, W]
  17. """
  18. x_1 = F.max_pool2d(x, 5, stride=1, padding=2)
  19. x_2 = F.max_pool2d(x, 9, stride=1, padding=4)
  20. x_3 = F.max_pool2d(x, 13, stride=1, padding=6)
  21. y = torch.cat([x, x_1, x_2, x_3], dim=1)
  22. return y
  23. def build_neck(cfg):
  24. model = cfg['neck']
  25. print('==============================')
  26. print('Neck: {}'.format(model))
  27. # build neck
  28. if model == 'spp':
  29. neck = SPP()
  30. return neck