build.py 810 B

12345678910111213141516171819202122232425262728293031
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. from .loss import build_criterion
  4. from .yolov5 import YOLOv5
  5. # build object detector
  6. def build_yolov5(args, cfg, device, num_classes=80, trainable=False):
  7. print('==============================')
  8. print('Build {} ...'.format(args.model.upper()))
  9. print('==============================')
  10. print('Model Configuration: \n', cfg)
  11. model = YOLOv5(
  12. cfg = cfg,
  13. device = device,
  14. num_classes = num_classes,
  15. conf_thresh = args.conf_thresh,
  16. nms_thresh = args.nms_thresh,
  17. topk = args.topk,
  18. trainable = trainable
  19. )
  20. criterion = None
  21. if trainable:
  22. # build criterion for training
  23. criterion = build_criterion(cfg, device, num_classes)
  24. return model, criterion