| 123456789101112131415161718 |
- from .cspdarknet import cspdarknet_n, cspdarknet_s, cspdarknet_m, cspdarknet_l, cspdarknet_x
def build_cspdarknet(args):
    """Build the CSPDarknet model variant selected by ``args.model``.

    Args:
        args: Configuration object exposing ``model`` (one of
            ``'cspdarknet_n'``, ``'cspdarknet_s'``, ``'cspdarknet_m'``,
            ``'cspdarknet_l'``, ``'cspdarknet_x'``), plus ``img_dim`` and
            ``num_classes``, which are forwarded positionally to the chosen
            constructor.

    Returns:
        Whatever the matching ``cspdarknet_*`` constructor returns.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the names above.
    """
    # One entry per supported variant; adding a new size only needs a new
    # entry here instead of another copy-pasted elif branch.
    builders = {
        'cspdarknet_n': cspdarknet_n,
        'cspdarknet_s': cspdarknet_s,
        'cspdarknet_m': cspdarknet_m,
        'cspdarknet_l': cspdarknet_l,
        'cspdarknet_x': cspdarknet_x,
    }

    builder = builders.get(args.model)
    if builder is None:
        raise NotImplementedError("Unknown cspdarknet: {}".format(args.model))

    return builder(args.img_dim, args.num_classes)
|