build.py 910 B

123456789101112131415161718192021222324252627282930313233
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. from .loss import build_criterion
  4. from .rtdetr import RTDETR
  5. # build object detector
  6. def build_rtdetr(args, cfg, device, num_classes=80, trainable=False, deploy=False):
  7. print('==============================')
  8. print('Build {} ...'.format(args.model.upper()))
  9. print('==============================')
  10. print('Model Configuration: \n', cfg)
  11. # -------------- Build rtdetr --------------
  12. model = RTDETR(
  13. cfg=cfg,
  14. device=device,
  15. num_classes=num_classes,
  16. trainable=trainable,
  17. aux_loss=trainable,
  18. with_box_refine=True,
  19. deploy=deploy
  20. )
  21. # -------------- Build criterion --------------
  22. criterion = None
  23. if trainable:
  24. # build criterion for training
  25. criterion = build_criterion(cfg, num_classes, aux_loss=True)
  26. return model, criterion