# mlp.py — Multi-Layer Perceptron model definition.
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import SLP
  5. except:
  6. from modules import SLP
  7. # Multi Layer Perceptron
  8. class MLP(nn.Module):
  9. def __init__(self,
  10. in_dim :int,
  11. inter_dim :int,
  12. out_dim :int,
  13. act_type :str = "sigmoid",
  14. norm_type :str = "bn") -> None:
  15. super().__init__()
  16. self.stem = SLP(in_dim, inter_dim, act_type, norm_type)
  17. self.layers = nn.Sequential(
  18. SLP(inter_dim, inter_dim, act_type, norm_type),
  19. SLP(inter_dim, inter_dim, act_type, norm_type),
  20. SLP(inter_dim, inter_dim, act_type, norm_type),
  21. SLP(inter_dim, inter_dim, act_type, norm_type),
  22. )
  23. self.fc = nn.Linear(inter_dim, out_dim)
  24. def forward(self, x):
  25. """
  26. Input:
  27. x : (torch.Tensor) -> [B, C, H, W] or [B, C]
  28. """
  29. if len(x.shape) > 2:
  30. x = x.flatten(1)
  31. x = self.stem(x)
  32. x = self.layers(x)
  33. x = self.fc(x)
  34. return x
  35. if __name__ == "__main__":
  36. bs, c = 8, 256
  37. hidden_dim = 512
  38. num_classes = 10
  39. # Make an input data randomly
  40. x = torch.randn(bs, c)
  41. # Build a MLP model
  42. model = MLP(in_dim=c, inter_dim=hidden_dim, out_dim=num_classes, act_type='sigmoid', norm_type='bn')
  43. # Inference
  44. output = model(x)
  45. print(output.shape)