gelan_neck.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. # SPP-ELAN (from yolov9)
  class SPPElan(nn.Module):
      """SPP-ELAN neck layer, an SPPF-style block adapted from YOLOv9.

      A 1x1 ``ConvModule`` projects the input to ``cfg.spp_inter_dim`` channels.
      That projection is then passed through the same stride-1 5x5 max-pool
      ``num_pools`` times in sequence, each pool consuming the previous pool's
      output. The projection and every pooled map are concatenated along dim 1
      and fused by a second 1x1 ``ConvModule`` into ``cfg.spp_out_dim`` channels.

      Args:
          cfg: config object; this class reads ``cfg.spp_inter_dim`` and
              ``cfg.spp_out_dim``.
          in_dim: number of input channels passed to ``conv_layer_1``.
      """

      # Number of chained max-pool stages. conv_layer_2 receives
      # (num_pools + 1) maps: the projection plus one map per pool stage.
      num_pools = 3

      def __init__(self, cfg, in_dim):
          super().__init__()
          ## ----------- Basic Parameters -----------
          self.in_dim = in_dim
          self.inter_dim = cfg.spp_inter_dim
          self.out_dim = cfg.spp_out_dim
          ## ----------- Network Parameters -----------
          self.conv_layer_1 = ConvModule(in_dim, self.inter_dim, kernel_size=1)
          self.conv_layer_2 = ConvModule(
              self.inter_dim * (self.num_pools + 1), self.out_dim, kernel_size=1
          )
          # kernel 5 / stride 1 / padding 2 keeps the spatial size unchanged,
          # so all pooled maps can be concatenated with the projection.
          self.pool_layer = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
          # Initialize all layers
          self.init_weights()

      def init_weights(self):
          """Re-run PyTorch's default init on every ``nn.Conv2d`` submodule.

          Only modules found by ``self.modules()`` that are ``nn.Conv2d``
          instances are touched (e.g. inside the ConvModules, if they wrap one).
          """
          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  m.reset_parameters()

      def forward(self, x):
          """Apply the SPP-ELAN block.

          Args:
              x: input tensor; presumably (B, in_dim, H, W) -- TODO confirm.

          Returns:
              ``conv_layer_2`` applied to the dim-1 concatenation of the
              projection and the ``num_pools`` chained pooling outputs.
          """
          feat = self.conv_layer_1(x)
          y = [feat]
          # Each stage pools the PREVIOUS stage's output (SPPF-style chaining).
          # An explicit loop keeps that dependency obvious; the former
          # `y.extend(pool(y[-1]) for ...)` only chained because CPython's
          # extend appends generator items one at a time.
          for _ in range(self.num_pools):
              feat = self.pool_layer(feat)
              y.append(feat)
          return self.conv_layer_2(torch.cat(y, 1))