| 1234567891011121314 |
- from .mlp import MLP
def build_mlp(args):
    """Build the model selected by ``args.model``.

    Only ``"mlp"`` is supported. In that case an ``MLP`` is constructed with
    ``in_dim=args.mlp_in_dim`` and ``out_dim=args.num_classes``. The remaining
    settings are fixed: ``inter_dim=1024``, ``act_type="relu"`` and
    ``norm_type="bn"``.

    Args:
        args: Configuration object exposing ``model``. When ``model == "mlp"``
            it must also expose ``mlp_in_dim`` and ``num_classes``.

    Returns:
        The constructed ``MLP`` instance.

    Raises:
        NotImplementedError: If ``args.model`` is anything other than ``"mlp"``.
    """
    # Reject unknown model names up front so the happy path stays flat.
    if args.model != "mlp":
        raise NotImplementedError("Unknown model: {}".format(args.model))

    return MLP(
        in_dim=args.mlp_in_dim,
        inter_dim=1024,
        out_dim=args.num_classes,
        act_type="relu",
        norm_type="bn",
    )
|