import torch
import torch.nn as nn
try:
    # Relative import when this file is used as part of a package.
    from .modules import SLP
except ImportError:
    try:
        from modules import SLP
    except ImportError:
        SLP = None
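
# Fallback sketch, used only when modules.py cannot be imported (e.g. when
# running this file in isolation). It assumes SLP ("single-layer perceptron")
# is a Linear -> Norm -> Activation block; the real implementation in
# modules.py may differ.
if SLP is None:
    class SLP(nn.Module):
        def __init__(self, in_dim, out_dim, act_type="sigmoid", norm_type="bn"):
            super().__init__()
            self.fc = nn.Linear(in_dim, out_dim)
            self.norm = nn.BatchNorm1d(out_dim) if norm_type == "bn" else nn.LayerNorm(out_dim)
            self.act = nn.Sigmoid() if act_type == "sigmoid" else nn.ReLU()

        def forward(self, x):
            return self.act(self.norm(self.fc(x)))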

# Multi-Layer Perceptron
class MLP(nn.Module):
    def __init__(self,
                 in_dim: int,
                 inter_dim: int,
                 out_dim: int,
                 act_type: str = "sigmoid",
                 norm_type: str = "bn") -> None:
        super().__init__()
        # Stem: project the input to the hidden width.
        self.stem = SLP(in_dim, inter_dim, act_type, norm_type)
        # Four hidden SLP blocks at the same width.
        self.layers = nn.Sequential(
            SLP(inter_dim, inter_dim, act_type, norm_type),
            SLP(inter_dim, inter_dim, act_type, norm_type),
            SLP(inter_dim, inter_dim, act_type, norm_type),
            SLP(inter_dim, inter_dim, act_type, norm_type),
        )
        # Linear classification head.
        self.fc = nn.Linear(inter_dim, out_dim)

    def forward(self, x):
        """
        Input:
            x: (torch.Tensor) -> [B, C, H, W] or [B, C]
        """
        # Flatten image-like inputs to [B, C*H*W].
        if x.dim() > 2:
            x = x.flatten(1)
        x = self.stem(x)
        x = self.layers(x)
        x = self.fc(x)
        return x

if __name__ == "__main__":
    bs, c = 8, 256
    hidden_dim = 512
    num_classes = 10

    # Create a random input tensor.
    x = torch.randn(bs, c)
    # Build an MLP model.
    model = MLP(in_dim=c, inter_dim=hidden_dim, out_dim=num_classes, act_type='sigmoid', norm_type='bn')
    # Inference
    output = model(x)
    print(output.shape)  # torch.Size([8, 10])
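
    # The forward pass also accepts image-like inputs and flattens them;
    # a quick check with a hypothetical [B, C, H, W] tensor whose
    # flattened size matches in_dim (4 * 8 * 8 = 256):
    x_img = torch.randn(bs, 4, 8, 8)
    output_img = model(x_img)
    print(output_img.shape)  # torch.Size([8, 10])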