build.py 406 B

1234567891011121314
  1. from .mlp import MLP
  2. def build_mlp(args):
  3. if args.model == "mlp":
  4. model = MLP(in_dim = args.mlp_in_dim,
  5. inter_dim = 1024,
  6. out_dim = args.num_classes,
  7. act_type = "relu",
  8. norm_type = "bn")
  9. else:
  10. raise NotImplementedError("Unknown model: {}".format(args.model))
  11. return model