# build.py
  1. from .gelan import gelan_s, gelan_c
  2. def build_gelan(args):
  3. # build vit model
  4. if args.model == 'gelan_s':
  5. model = gelan_s(args.img_dim, args.num_classes)
  6. elif args.model == 'gelan_c':
  7. model = gelan_c(args.img_dim, args.num_classes)
  8. else:
  9. raise NotImplementedError("Unknown elannet: {}".format(args.model))
  10. return model