build.py 943 B

12345678910111213141516171819202122232425262728293031
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. import torch.nn as nn
  5. from .loss import build_criterion
  6. from .rtrdet import RTRDet
  # -------------- Build object detector --------------
  def build_rtrdet(args, cfg, device, num_classes=80, trainable=False, deploy=False):
      """Build an RTRDet detector and, when training, its loss criterion.

      Args:
          args: Namespace-like object; only ``args.model`` is read (upper-cased
              for the banner printed below, so it is expected to be a string).
          cfg: Model configuration, forwarded unchanged to ``RTRDet`` and to
              ``build_criterion``.
          device: Forwarded unchanged to ``RTRDet``.
          num_classes: Number of object classes (default 80); forwarded to both
              ``RTRDet`` and ``build_criterion``.
          trainable: If truthy, the model is built with auxiliary losses enabled
              and a matching criterion is constructed.
          deploy: Forwarded unchanged to ``RTRDet``.

      Returns:
          A ``(model, criterion)`` tuple. ``criterion`` is ``None`` when
          ``trainable`` is falsy.
      """
      print('==============================')
      print('Build {} ...'.format(args.model.upper()))

      # Auxiliary losses are only enabled for training. Keeping the flag in one
      # place keeps the model and criterion settings in sync (presumably the
      # criterion consumes the auxiliary outputs the model produces -- confirm
      # against RTRDet / build_criterion).
      aux_loss = bool(trainable)

      # -------------- Build RTRDet --------------
      model = RTRDet(cfg=cfg,
                     device=device,
                     num_classes=num_classes,
                     trainable=trainable,
                     aux_loss=aux_loss,
                     deploy=deploy)

      # -------------- Build criterion --------------
      criterion = None
      if trainable:
          # Only needed for training; inference/deploy callers get None.
          criterion = build_criterion(cfg, num_classes, aux_loss=aux_loss)

      return model, criterion