| 123456789101112131415161718192021222324252627282930313233343536 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- import torch
- import torch.nn as nn
- from .loss import build_criterion
- from .rtpdetr import RT_PDETR
- # build object detector
- def build_rtpdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
- print('==============================')
- print('Build {} ...'.format(args.model.upper()))
-
- print('==============================')
- print('Model Configuration: \n', cfg)
-
- # -------------- Build RT-DETR --------------
- model = RT_PDETR(cfg = cfg,
- num_classes = num_classes,
- conf_thresh = args.conf_thresh,
- topk = 300,
- deploy = deploy,
- no_multi_labels = args.no_multi_labels,
- use_nms = True, # NMS is beneficial
- nms_class_agnostic = args.nms_class_agnostic
- )
-
- # -------------- Build criterion --------------
- criterion = None
- if trainable:
- # build criterion for training
- criterion = build_criterion(cfg, num_classes, aux_loss=True)
-
- return model, criterion
|