build.py 938 B

123456789101112131415161718192021222324
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. from .criterion import build_criterion
  4. from .retinanet import RetinaNet
  5. # build RetinaNet
  6. def build_retinanet(cfg, num_classes=80, is_val=False):
  7. # -------------- Build RetinaNet --------------
  8. model = RetinaNet(cfg = cfg,
  9. num_classes = num_classes,
  10. conf_thresh = cfg['train_conf_thresh'] if is_val else cfg['test_conf_thresh'],
  11. nms_thresh = cfg['train_nms_thresh'] if is_val else cfg['test_nms_thresh'],
  12. topk = cfg['train_topk'] if is_val else cfg['test_topk'],
  13. ca_nms = False if is_val else cfg['nms_class_agnostic'])
  14. # -------------- Build Criterion --------------
  15. criterion = None
  16. if is_val:
  17. # build criterion for training
  18. criterion = build_criterion(cfg, num_classes)
  19. return model, criterion