__init__.py 512 B

1234567891011121314151617
  1. from .mlp.build import build_mlp
  2. from .convnet.build import build_convnet
  3. from .vit.build import build_vit
  4. def build_model(args):
  5. # --------------------------- ResNet series ---------------------------
  6. if 'mlp' in args.model:
  7. model = build_mlp(args)
  8. elif 'convnet' in args.model:
  9. model = build_convnet(args)
  10. elif 'vit' in args.model:
  11. model = build_vit(args)
  12. else:
  13. raise NotImplementedError("Unknown model: {}".format(args.model))
  14. return model