modules.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. # --------------------- Basic modules ---------------------
  class ConvModule(nn.Module):
      """Conv2d -> BatchNorm2d -> SiLU building block.

      Args:
          in_dim: number of input channels.
          out_dim: number of output channels.
          kernel_size: convolution kernel size; padding is ``kernel_size // 2``,
              which keeps spatial size for odd kernels at stride 1.
          stride: convolution stride.
      """
      def __init__(self,
                   in_dim: int,           # in channels
                   out_dim: int,          # out channels
                   kernel_size: int = 1,  # kernel size
                   stride: int = 1,       # stride
                   ):
          super(ConvModule, self).__init__()
          convs = []
          # Bias is omitted because the following BatchNorm supplies its own shift.
          convs.append(nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size,
                                 padding=kernel_size // 2, stride=stride, bias=False))
          convs.append(nn.BatchNorm2d(out_dim))
          convs.append(nn.SiLU(inplace=True))
          self.convs = nn.Sequential(*convs)

      def forward(self, x):
          """Apply conv, batch norm and SiLU in sequence."""
          return self.convs(x)
  class ELANBlock(nn.Module):
      """ELAN-style block that fuses two 1x1 stems and two chained 3x3 stacks.

      ``cv1`` and ``cv2`` both project the input to ``int(in_dim * expansion)``
      channels. ``cv3`` refines the ``cv2`` output, and ``cv4`` refines the
      ``cv3`` output. All four feature maps are concatenated and fused by a
      1x1 conv into ``out_dim`` channels.

      Args:
          in_dim: number of input channels.
          out_dim: number of output channels.
          expansion: width ratio for the hidden features.
          branch_depth: number of 3x3 ConvModules in each of cv3 / cv4
              (rounded to an int).
      """
      def __init__(self,
                   in_dim: int,
                   out_dim: int,
                   expansion: float = 0.5,
                   branch_depth: int = 2,
                   ):
          super(ELANBlock, self).__init__()
          hidden = int(in_dim * expansion)
          depth = round(branch_depth)
          self.cv1 = ConvModule(in_dim, hidden, kernel_size=1)
          self.cv2 = ConvModule(in_dim, hidden, kernel_size=1)
          self.cv3 = nn.Sequential(
              *(ConvModule(hidden, hidden, kernel_size=3) for _ in range(depth)))
          self.cv4 = nn.Sequential(
              *(ConvModule(hidden, hidden, kernel_size=3) for _ in range(depth)))
          # Four hidden-width features are concatenated before this fusion conv.
          self.out = ConvModule(hidden * 4, out_dim, kernel_size=1)

      def forward(self, x):
          """Compute the four intermediate features and fuse them."""
          feats = [self.cv1(x), self.cv2(x)]
          feats.append(self.cv3(feats[-1]))  # refines the cv2 output
          feats.append(self.cv4(feats[-1]))  # refines the cv3 output
          return self.out(torch.cat(feats, dim=1))
  class ELANBlockFPN(nn.Module):
      """ELAN-style block with a configurable number of chained 3x3 branches.

      ``cv1`` and ``cv2`` project the input to ``inter_dim = int(in_dim * expansion)``
      channels. Each entry of ``cv3`` consumes the previous feature and produces
      ``inter_dim2 = int(inter_dim * expansion)`` channels; the first consumes the
      ``cv2`` output. All features are concatenated and fused by a 1x1 conv into
      ``out_dim`` channels.

      Args:
          in_dim: number of input channels.
          out_dim: number of output channels.
          expansion: width ratio applied once for the stems and again for the branches.
          branch_width: number of chained branches in ``cv3`` (rounded to an int).
          branch_depth: number of 3x3 ConvModules per branch (rounded to an int).
              With a depth above 1 a branch is an ``nn.Sequential``; otherwise it is
              a single ConvModule.
      """
      def __init__(self,
                   in_dim: int,
                   out_dim: int,
                   expansion: float = 0.5,
                   branch_width: int = 4,
                   branch_depth: int = 1,
                   ):
          super(ELANBlockFPN, self).__init__()
          inter_dim = int(in_dim * expansion)
          inter_dim2 = int(inter_dim * expansion)
          depth = round(branch_depth)

          self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
          self.cv2 = ConvModule(in_dim, inter_dim, kernel_size=1)
          self.cv3 = nn.ModuleList()
          for branch_idx in range(round(branch_width)):
              # Only the first branch reads the inter_dim-wide cv2 output.
              first_in = inter_dim if branch_idx == 0 else inter_dim2
              layers = [ConvModule(first_in, inter_dim2, kernel_size=3)]
              layers.extend(ConvModule(inter_dim2, inter_dim2, kernel_size=3)
                            for _ in range(1, depth))
              self.cv3.append(nn.Sequential(*layers) if depth > 1 else layers[0])

          fused_dim = inter_dim * 2 + inter_dim2 * len(self.cv3)
          self.out = ConvModule(fused_dim, out_dim, kernel_size=1)

      def forward(self, x):
          """Run the stems and the chained branches, then fuse every feature."""
          feats = [self.cv1(x), self.cv2(x)]
          for branch in self.cv3:
              feats.append(branch(feats[-1]))
          return self.out(torch.cat(feats, dim=1))
  class DownSample(nn.Module):
      """2x spatial downsampling via concatenated max-pool and strided-conv branches.

      Pool branch: MaxPool(2, 2), then a 1x1 ConvModule to ``out_dim // 2`` channels.
      Conv branch: a 1x1 ConvModule, then a 3x3 stride-2 ConvModule, with the
      remaining ``out_dim - out_dim // 2`` channels.

      Args:
          in_dim: number of input channels.
          out_dim: number of output channels (total across both branches).
      """
      def __init__(self, in_dim, out_dim):
          super().__init__()
          inter_dim = out_dim // 2
          # The 3x3/stride-2 conv (padding 1) yields ceil(H/2) x ceil(W/2).
          # ceil_mode=True makes the pool branch match for odd sizes too, so
          # torch.cat does not fail. Even sizes are unaffected.
          self.mp = nn.MaxPool2d((2, 2), 2, ceil_mode=True)
          self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
          # The conv branch gets the remaining channels so the concatenated output
          # has exactly out_dim channels even when out_dim is odd.
          conv_dim = out_dim - inter_dim
          self.cv2 = nn.Sequential(
              ConvModule(in_dim, conv_dim, kernel_size=1),
              ConvModule(conv_dim, conv_dim, kernel_size=3, stride=2),
          )

      def forward(self, x):
          """Return the channel-wise concatenation of the pool and conv branches."""
          x1 = self.cv1(self.mp(x))
          x2 = self.cv2(x)
          out = torch.cat([x1, x2], dim=1)
          return out