__init__.py 620 B

1234567891011121314151617181920
  1. from .mlp.build import build_mlp
  2. from .convnet.build import build_convnet
  3. from .resnet.build import build_resnet
  4. from .vit.build import build_vit
  def build_model(args):
      """Build a model by dispatching on the family key found in ``args.model``.

      The family is chosen by *substring* match against ``args.model``. Keys
      are tried in the order mlp, convnet, resnet, vit, and the first key that
      occurs anywhere in the name wins. ``args`` is then forwarded unchanged to
      that family's builder.

      Args:
          args: Configuration object. Only ``args.model`` is read here; the
              selected ``build_*`` function presumably reads further
              attributes (not visible from this module).

      Returns:
          Whatever the selected family builder returns.

      Raises:
          NotImplementedError: If none of the family keys occurs in
              ``args.model``.
      """
      # Ordered (key, builder) pairs covering every supported family.
      # Order is significant because matching is by substring: a model name
      # containing more than one key is handled by the earliest entry here.
      # To support a new family, add one pair in the desired priority slot.
      family_builders = (
          ('mlp', build_mlp),
          ('convnet', build_convnet),
          ('resnet', build_resnet),
          ('vit', build_vit),
      )
      for key, builder in family_builders:
          if key in args.model:
              return builder(args)
      raise NotImplementedError("Unknown model: {}".format(args.model))