modules.py

import torch
import torch.nn as nn


def get_activation(act_type=None):
    """Return an activation module by name; None yields an identity."""
    if act_type == 'sigmoid':
        return nn.Sigmoid()
    elif act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported activation: {act_type}")
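
# Note: the inplace=True activations above overwrite their input tensor; this
# saves memory but is only safe when the pre-activation values are not reused
# elsewhere in the graph.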


def get_norm(norm_type, dim):
    """Return a normalization module by name; None yields an identity."""
    if norm_type == 'bn':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'ln':
        return LayerNorm2d(dim)  # channels-first LayerNorm, defined below
    elif norm_type == 'gn':
        # num_groups=32 requires dim to be divisible by 32; nn.GroupNorm
        # raises a ValueError at construction time otherwise.
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported norm: {norm_type}")


class LayerNorm2d(nn.Module):
    """LayerNorm for channels-first (N, C, H, W) tensors: normalizes over the
    channel dimension at every spatial location."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)               # per-position channel mean
        s = (x - u).pow(2).mean(1, keepdim=True)  # biased channel variance
        x = (x - u) / torch.sqrt(s + self.eps)
        # Broadcast the per-channel affine parameters over H and W.
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
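
# Sanity note: for x of shape (N, C, H, W), LayerNorm2d(C)(x) matches
#   nn.LayerNorm(C, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
# (checked numerically in the __main__ block at the end of this file).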


# Basic convolutional module: Conv -> Norm -> Activation, with an optional
# depthwise-separable variant (depthwise k x k conv followed by a pointwise
# 1 x 1 conv, each with its own normalization).
class ConvModule(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 kernel_size: int = 1,
                 padding: int = 0,
                 stride: int = 1,
                 act_type: str = "relu",
                 norm_type: str = "bn",
                 depthwise: bool = False) -> None:
        super().__init__()
        # The conv bias is redundant when a normalization layer follows.
        use_bias = norm_type is None
        self.depthwise = depthwise
        if not depthwise:
            self.conv = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
                                  kernel_size=kernel_size, padding=padding, stride=stride,
                                  bias=use_bias)
            self.norm = get_norm(norm_type, out_dim)
        else:
            # Depthwise conv: one filter per input channel (groups=in_dim).
            self.conv1 = nn.Conv2d(in_channels=in_dim, out_channels=in_dim,
                                   kernel_size=kernel_size, padding=padding, stride=stride,
                                   groups=in_dim, bias=use_bias)
            self.norm1 = get_norm(norm_type, in_dim)
            # Pointwise 1x1 conv mixes channels and projects to out_dim.
            self.conv2 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
                                   kernel_size=1, padding=0, stride=1,
                                   bias=use_bias)
            self.norm2 = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.depthwise:
            # Note: the activation is applied only after the pointwise stage.
            x = self.norm1(self.conv1(x))
            x = self.act(self.norm2(self.conv2(x)))
        else:
            x = self.act(self.norm(self.conv(x)))
        return x
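

# Minimal smoke test (a sketch; the shapes and hyperparameters below are
# arbitrary assumptions for illustration, not part of the module's API).
if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)

    # LayerNorm2d should match nn.LayerNorm applied over the channel axis.
    ln2d = LayerNorm2d(64)
    ref = nn.LayerNorm(64, eps=1e-6)
    y1 = ln2d(x)
    y2 = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    print("LayerNorm2d max abs diff:", (y1 - y2).abs().max().item())

    # Standard conv block: 64 -> 128 channels, 3x3, same spatial size.
    m = ConvModule(64, 128, kernel_size=3, padding=1,
                   act_type="silu", norm_type="bn")
    print("standard:", m(x).shape)    # torch.Size([2, 128, 32, 32])

    # Depthwise-separable variant with stride-2 downsampling; 'gn' works here
    # because both 64 and 128 are divisible by 32.
    dw = ConvModule(64, 128, kernel_size=3, padding=1, stride=2,
                    act_type="lrelu", norm_type="gn", depthwise=True)
    print("depthwise:", dw(x).shape)  # torch.Size([2, 128, 16, 16])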