modules.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import torch
  2. import torch.nn as nn
  3. # --------------------- Basic modules ---------------------
  4. class ConvModule(nn.Module):
  5. def __init__(self,
  6. in_dim: int, # in channels
  7. out_dim: int, # out channels
  8. kernel_size: int = 1, # kernel size
  9. stride:int = 1, # padding
  10. ):
  11. super(ConvModule, self).__init__()
  12. convs = []
  13. convs.append(nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, bias=False))
  14. convs.append(nn.BatchNorm2d(out_dim))
  15. convs.append(nn.SiLU(inplace=True))
  16. self.convs = nn.Sequential(*convs)
  17. def forward(self, x):
  18. return self.convs(x)
  19. class Bottleneck(nn.Module):
  20. def __init__(self,
  21. in_dim: int,
  22. out_dim: int,
  23. expand_ratio: float = 0.5,
  24. shortcut: bool = False,
  25. ):
  26. super(Bottleneck, self).__init__()
  27. inter_dim = int(out_dim * expand_ratio) # hidden channels
  28. self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
  29. self.cv2 = ConvModule(inter_dim, out_dim, kernel_size=3, stride=1)
  30. self.shortcut = shortcut and in_dim == out_dim
  31. def forward(self, x):
  32. h = self.cv2(self.cv1(x))
  33. return x + h if self.shortcut else h
  34. class ResBlock(nn.Module):
  35. def __init__(self,
  36. in_dim: int,
  37. out_dim: int,
  38. num_blocks: int = 1,
  39. ):
  40. super(ResBlock, self).__init__()
  41. assert in_dim == out_dim
  42. self.m = nn.Sequential(*[
  43. Bottleneck(in_dim, out_dim, expand_ratio=0.5, shortcut=True)
  44. for _ in range(num_blocks)
  45. ])
  46. def forward(self, x):
  47. return self.m(x)
  48. class ConvBlocks(nn.Module):
  49. def __init__(self, in_dim: int, out_dim: int):
  50. super().__init__()
  51. inter_dim = out_dim // 2
  52. self.convs = nn.Sequential(
  53. ConvModule(in_dim, out_dim, kernel_size=1),
  54. ConvModule(out_dim, inter_dim, kernel_size=3, stride=1),
  55. ConvModule(inter_dim, out_dim, kernel_size=1),
  56. ConvModule(out_dim, inter_dim, kernel_size=3, stride=1),
  57. ConvModule(inter_dim, out_dim, kernel_size=1)
  58. )
  59. def forward(self, x):
  60. return self.convs(x)