|
@@ -3,26 +3,29 @@ import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
- from .basic_modules.basic import MLP
|
|
|
|
|
|
|
+ from .basic_modules.basic import MLP, multiclass_nms
|
|
|
from .basic_modules.transformer import get_clones
|
|
from .basic_modules.transformer import get_clones
|
|
|
from .rtpdetr_encoder import build_image_encoder
|
|
from .rtpdetr_encoder import build_image_encoder
|
|
|
from .rtpdetr_decoder import build_transformer
|
|
from .rtpdetr_decoder import build_transformer
|
|
|
except:
|
|
except:
|
|
|
- from basic_modules.basic import MLP
|
|
|
|
|
|
|
+ from basic_modules.basic import MLP, multiclass_nms
|
|
|
from basic_modules.transformer import get_clones
|
|
from basic_modules.transformer import get_clones
|
|
|
from rtpdetr_encoder import build_image_encoder
|
|
from rtpdetr_encoder import build_image_encoder
|
|
|
from rtpdetr_decoder import build_transformer
|
|
from rtpdetr_decoder import build_transformer
|
|
|
|
|
|
|
|
|
|
|
|
|
-# Real-time Plain Transformer-based Object Detector
|
|
|
|
|
|
|
+# Real-time PlainDETR
|
|
|
class RT_PDETR(nn.Module):
|
|
class RT_PDETR(nn.Module):
|
|
|
def __init__(self,
|
|
def __init__(self,
|
|
|
cfg,
|
|
cfg,
|
|
|
num_classes = 80,
|
|
num_classes = 80,
|
|
|
conf_thresh = 0.1,
|
|
conf_thresh = 0.1,
|
|
|
|
|
+ nms_thresh = 0.5,
|
|
|
topk = 300,
|
|
topk = 300,
|
|
|
deploy = False,
|
|
deploy = False,
|
|
|
no_multi_labels = False,
|
|
no_multi_labels = False,
|
|
|
|
|
+ use_nms = False,
|
|
|
|
|
+ nms_class_agnostic = False,
|
|
|
aux_loss = False,
|
|
aux_loss = False,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
@@ -33,11 +36,15 @@ class RT_PDETR(nn.Module):
|
|
|
self.num_classes = num_classes
|
|
self.num_classes = num_classes
|
|
|
self.num_topk = topk
|
|
self.num_topk = topk
|
|
|
self.aux_loss = aux_loss
|
|
self.aux_loss = aux_loss
|
|
|
- self.conf_thresh = conf_thresh
|
|
|
|
|
- self.no_multi_labels = no_multi_labels
|
|
|
|
|
self.deploy = deploy
|
|
self.deploy = deploy
|
|
|
# scale hidden channels by width_factor
|
|
# scale hidden channels by width_factor
|
|
|
cfg['hidden_dim'] = round(cfg['hidden_dim'] * cfg['width'])
|
|
cfg['hidden_dim'] = round(cfg['hidden_dim'] * cfg['width'])
|
|
|
|
|
+ ## Post-process parameters
|
|
|
|
|
+ self.use_nms = use_nms
|
|
|
|
|
+ self.nms_thresh = nms_thresh
|
|
|
|
|
+ self.conf_thresh = conf_thresh
|
|
|
|
|
+ self.no_multi_labels = no_multi_labels
|
|
|
|
|
+ self.nms_class_agnostic = nms_class_agnostic
|
|
|
|
|
|
|
|
# ----------- Network setting -----------
|
|
# ----------- Network setting -----------
|
|
|
## Image encoder
|
|
## Image encoder
|
|
@@ -162,6 +169,15 @@ class RT_PDETR(nn.Module):
|
|
|
topk_labels = topk_idxs % self.num_classes
|
|
topk_labels = topk_idxs % self.num_classes
|
|
|
topk_bboxes = box_pred[topk_box_idxs]
|
|
topk_bboxes = box_pred[topk_box_idxs]
|
|
|
|
|
|
|
|
|
|
+ topk_scores = topk_scores.cpu().numpy()
|
|
|
|
|
+ topk_labels = topk_labels.cpu().numpy()
|
|
|
|
|
+ topk_bboxes = topk_bboxes.cpu().numpy()
|
|
|
|
|
+
|
|
|
|
|
+ # nms
|
|
|
|
|
+ if self.use_nms:
|
|
|
|
|
+ topk_scores, topk_labels, topk_bboxes = multiclass_nms(
|
|
|
|
|
+ topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
|
|
|
|
|
+
|
|
|
return topk_bboxes, topk_scores, topk_labels
|
|
return topk_bboxes, topk_scores, topk_labels
|
|
|
|
|
|
|
|
@torch.jit.unused
|
|
@torch.jit.unused
|
|
@@ -225,9 +241,9 @@ class RT_PDETR(nn.Module):
|
|
|
bboxes, scores, labels = self.post_process(box_pred, cls_pred)
|
|
bboxes, scores, labels = self.post_process(box_pred, cls_pred)
|
|
|
|
|
|
|
|
outputs = {
|
|
outputs = {
|
|
|
- "scores": scores.cpu().numpy(),
|
|
|
|
|
- "labels": labels.cpu().numpy(),
|
|
|
|
|
- "bboxes": bboxes.cpu().numpy(),
|
|
|
|
|
|
|
+ "scores": scores,
|
|
|
|
|
+ "labels": labels,
|
|
|
|
|
+ "bboxes": bboxes,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return outputs
|
|
return outputs
|