```python
import torch
import torch.nn as nn


def get_activation(act_type=None):
    """Return an activation module by name; None maps to Identity."""
    if act_type == 'sigmoid':
        return nn.Sigmoid()
    elif act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unknown activation type: {act_type}")


def get_norm(norm_type, dim):
    """Return a normalization module by name; None maps to Identity."""
    if norm_type == 'bn':
        return nn.BatchNorm1d(dim)
    elif norm_type == 'ln':
        return nn.LayerNorm(dim)
    elif norm_type == 'gn':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unknown norm type: {norm_type}")


# Single Layer Perceptron: Linear -> Norm -> Activation
class SLP(nn.Module):
    def __init__(self,
                 in_dim:    int,
                 out_dim:   int,
                 act_type:  str = "sigmoid",
                 norm_type: str = "bn") -> None:
        super().__init__()
        # The linear bias is redundant when a normalization layer follows it.
        use_bias = norm_type is None
        self.layer = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.layer(x)))
```
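As a quick sanity check, here is a minimal usage sketch (not part of the original listing; the dimensions and batch size are illustrative assumptions). It builds an `SLP` that projects 64-dimensional inputs to 32 dimensions with batch norm and ReLU, then runs a random batch through it:

```python
# Hypothetical usage: 64 -> 32 projection with batch norm and ReLU.
mlp = SLP(in_dim=64, out_dim=32, act_type='relu', norm_type='bn')

x = torch.randn(8, 64)   # (batch, in_dim)
y = mlp(x)               # (batch, out_dim)
print(y.shape)           # torch.Size([8, 32])
```

Note that with `norm_type='bn'`, the input must be a batch of at least two samples in training mode, since `BatchNorm1d` needs per-feature statistics over the batch dimension.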