modules.py

import torch
import torch.nn as nn
from typing import Optional


def get_activation(act_type: Optional[str] = None) -> nn.Module:
    """Return an activation module by name; None yields an identity."""
    if act_type == 'sigmoid':
        return nn.Sigmoid()
    elif act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported activation: {act_type}")
def get_norm(norm_type: Optional[str], dim: int) -> nn.Module:
    """Return a normalization module by name; None yields an identity."""
    if norm_type == 'bn':
        return nn.BatchNorm1d(dim)
    elif norm_type == 'ln':
        return nn.LayerNorm(dim)
    elif norm_type == 'gn':
        # Fixed 32 groups: dim must be divisible by 32.
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported normalization: {norm_type}")
# Single Layer Perceptron: Linear -> Norm -> Activation
class SLP(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 act_type: Optional[str] = "sigmoid",
                 norm_type: Optional[str] = "bn") -> None:
        super().__init__()
        # The linear bias is redundant when a normalization layer follows,
        # since the norm re-centers the pre-activation anyway.
        use_bias = norm_type is None
        self.layer = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(self.layer(x)))
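

# Usage sketch (illustrative, not part of the original module): builds one SLP
# block and runs a forward pass. The batch size (8) and dimensions (64 -> 128)
# are assumptions chosen for the example; with norm_type='gn', out_dim would
# additionally need to be divisible by 32.
if __name__ == "__main__":
    block = SLP(in_dim=64, out_dim=128, act_type="relu", norm_type="bn")
    x = torch.randn(8, 64)   # (batch, in_dim)
    y = block(x)             # (batch, out_dim)
    print(y.shape)           # torch.Size([8, 128])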