# build.py
  1. from .cspdarknet import cspdarknet_n, cspdarknet_s, cspdarknet_m, cspdarknet_l, cspdarknet_x
  2. def build_cspdarknet(args):
  3. # build vit model
  4. if args.model == 'cspdarknet_n':
  5. model = cspdarknet_n(args.img_dim, args.num_classes)
  6. elif args.model == 'cspdarknet_s':
  7. model = cspdarknet_s(args.img_dim, args.num_classes)
  8. elif args.model == 'cspdarknet_m':
  9. model = cspdarknet_m(args.img_dim, args.num_classes)
  10. elif args.model == 'cspdarknet_l':
  11. model = cspdarknet_l(args.img_dim, args.num_classes)
  12. elif args.model == 'cspdarknet_x':
  13. model = cspdarknet_x(args.img_dim, args.num_classes)
  14. else:
  15. raise NotImplementedError("Unknown cspdarknet: {}".format(args.model))
  16. return model