modules.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import torch
  2. import torch.nn as nn
  3. # --------------------- Basic modules ---------------------
  4. class ConvModule(nn.Module):
  5. def __init__(self,
  6. in_dim: int, # in channels
  7. out_dim: int, # out channels
  8. kernel_size: int = 1, # kernel size
  9. stride:int = 1, # padding
  10. ):
  11. super(ConvModule, self).__init__()
  12. convs = []
  13. convs.append(nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, bias=False))
  14. convs.append(nn.BatchNorm2d(out_dim))
  15. convs.append(nn.SiLU(inplace=True))
  16. self.convs = nn.Sequential(*convs)
  17. def forward(self, x):
  18. return self.convs(x)
  19. class Bottleneck(nn.Module):
  20. def __init__(self,
  21. in_dim: int,
  22. out_dim: int,
  23. expand_ratio: float = 0.5,
  24. shortcut: bool = False,
  25. ):
  26. super(Bottleneck, self).__init__()
  27. inter_dim = int(out_dim * expand_ratio) # hidden channels
  28. self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
  29. self.cv2 = ConvModule(inter_dim, out_dim, kernel_size=3, stride=1)
  30. self.shortcut = shortcut and in_dim == out_dim
  31. def forward(self, x):
  32. h = self.cv2(self.cv1(x))
  33. return x + h if self.shortcut else h
  34. class ResBlock(nn.Module):
  35. def __init__(self,
  36. in_dim: int,
  37. out_dim: int,
  38. num_blocks: int = 1,
  39. ):
  40. super(ResBlock, self).__init__()
  41. assert in_dim == out_dim
  42. blocks = []
  43. for i in range(num_blocks):
  44. if i == 0:
  45. blocks.append(Bottleneck(in_dim, out_dim, expand_ratio=0.5, shortcut=True))
  46. else:
  47. blocks.append(Bottleneck(out_dim, out_dim, expand_ratio=0.5, shortcut=True))
  48. self.m = nn.Sequential(*blocks)
  49. def forward(self, x):
  50. return self.m(x)
  51. class ConvBlocks(nn.Module):
  52. def __init__(self, in_dim: int, out_dim: int):
  53. super().__init__()
  54. inter_dim = out_dim // 2
  55. self.convs = nn.Sequential(
  56. ConvModule(in_dim, out_dim, kernel_size=1),
  57. ConvModule(out_dim, inter_dim, kernel_size=3, stride=1),
  58. ConvModule(inter_dim, out_dim, kernel_size=1),
  59. ConvModule(out_dim, inter_dim, kernel_size=3, stride=1),
  60. ConvModule(inter_dim, out_dim, kernel_size=1)
  61. )
  62. def forward(self, x):
  63. return self.convs(x)