# gelan_neck.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. # SPP-ELAN (from yolov9)
  8. class SPPElan(nn.Module):
  9. def __init__(self, cfg, in_dim):
  10. """SPPElan looks like the SPPF."""
  11. super().__init__()
  12. ## ----------- Basic Parameters -----------
  13. self.in_dim = in_dim
  14. self.inter_dim = cfg.spp_inter_dim
  15. self.out_dim = cfg.spp_out_dim
  16. ## ----------- Network Parameters -----------
  17. self.conv_layer_1 = ConvModule(in_dim, self.inter_dim, kernel_size=1)
  18. self.conv_layer_2 = ConvModule(self.inter_dim * 4, self.out_dim, kernel_size=1)
  19. self.pool_layer = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
  20. # Initialize all layers
  21. self.init_weights()
  22. def init_weights(self):
  23. """Initialize the parameters."""
  24. for m in self.modules():
  25. if isinstance(m, torch.nn.Conv2d):
  26. m.reset_parameters()
  27. def forward(self, x):
  28. y = [self.conv_layer_1(x)]
  29. y.extend(self.pool_layer(y[-1]) for _ in range(3))
  30. return self.conv_layer_2(torch.cat(y, 1))
  31. if __name__=='__main__':
  32. from thop import profile
  33. class BaseConfig(object):
  34. def __init__(self) -> None:
  35. self.spp_inter_dim = 512
  36. self.spp_out_dim = 512
  37. # 定义模型配置文件
  38. cfg = BaseConfig()
  39. # Build a neck
  40. in_dim = 512
  41. model = SPPElan(cfg, in_dim)
  42. # Randomly generate a input data
  43. x = torch.randn(2, in_dim, 20, 20)
  44. # Inference
  45. output = model(x)
  46. print(' - the shape of input : ', x.shape)
  47. print(' - the shape of output : ', output.shape)
  48. x = torch.randn(1, in_dim, 20, 20)
  49. flops, params = profile(model, inputs=(x, ), verbose=False)
  50. print('============== FLOPs & Params ================')
  51. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  52. print(' - Params : {:.2f} M'.format(params / 1e6))