浏览代码

debug nms for RT-DETR

yjh0410 1 年之前
父节点
当前提交
2c5307f71c
共有 3 个文件被更改,包括 97 次插入3 次删除
  1. 70 0
      models/detectors/rtdetr/basic_modules/basic.py
  2. 3 0
      models/detectors/rtdetr/build.py
  3. 24 3
      models/detectors/rtdetr/rtdetr.py

+ 70 - 0
models/detectors/rtdetr/basic_modules/basic.py

@@ -1,7 +1,77 @@
+import numpy as np
 import torch
 import torch.nn as nn
 
 
+# ---------------------------- NMS ----------------------------
+## basic NMS
+def nms(bboxes, scores, nms_thresh):
+    """Greedy NumPy NMS; returns the kept indices into bboxes, best score first.
+
+    bboxes is read as [:, 0..3] = (xmin, ymin, xmax, ymax); scores is a NumPy
+    array; a box is dropped when its IoU with a kept box exceeds nms_thresh.
+    """
+    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
+    areas = (x2 - x1) * (y2 - y1)
+    order = scores.argsort()[::-1]  # candidate indices, descending score
+
+    keep = []
+    while order.size > 0:
+        i, rest = order[0], order[1:]
+        keep.append(i)
+        # corners of the intersection between box i and every remaining box
+        xx1 = np.maximum(x1[i], x1[rest])
+        yy1 = np.maximum(y1[i], y1[rest])
+        xx2 = np.minimum(x2[i], x2[rest])
+        yy2 = np.minimum(y2[i], y2[rest])
+        # clamp at 0 (not 1e-10) so disjoint boxes get an IoU of exactly 0
+        # and are therefore not suppressed when nms_thresh == 0
+        w = np.maximum(0.0, xx2 - xx1)
+        h = np.maximum(0.0, yy2 - yy1)
+        inter = w * h
+        # the epsilon avoids 0/0 when both boxes have zero area
+        iou = inter / (areas[i] + areas[rest] - inter + 1e-14)
+        # only boxes whose IoU with box i is <= nms_thresh stay candidates
+        order = rest[iou <= nms_thresh]
+
+    return keep
+
+## class-agnostic NMS 
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+    """Run one NMS pass over all boxes, ignoring their class labels.
+
+    Returns the surviving (scores, labels, bboxes), each selected with the
+    index list produced by ``nms``.
+    """
+    kept = nms(bboxes, scores, nms_thresh)
+    return scores[kept], labels[kept], bboxes[kept]
+
+## class-aware NMS 
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+    """Run ``nms`` separately for each class id in ``range(num_classes)``.
+
+    Returns the surviving (scores, labels, bboxes) in their input order;
+    entries whose label is outside that range are never selected.
+    """
+    # boolean mask over the inputs, set for every detection that survives
+    selected = np.zeros(len(bboxes), dtype=bool)
+    for cls_id in range(num_classes):
+        cls_inds = np.where(labels == cls_id)[0]
+        if cls_inds.size == 0:
+            continue
+        # nms yields positions within this class subset; map them back
+        cls_keep = nms(bboxes[cls_inds], scores[cls_inds], nms_thresh)
+        selected[cls_inds[cls_keep]] = True
+
+    return scores[selected], labels[selected], bboxes[selected]
+
+## multi-class NMS 
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+    """Dispatch to class-agnostic or per-class NMS based on ``class_agnostic``."""
+    if not class_agnostic:
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
+    return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+
 # ----------------- MLP modules -----------------
 class MLP(nn.Module):
     def __init__(self, in_dim, hidden_dim, out_dim, num_layers):

+ 3 - 0
models/detectors/rtdetr/build.py

@@ -19,10 +19,13 @@ def build_rtdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
     # -------------- Build RT-DETR --------------
     model = RT_DETR(cfg             = cfg,
                     num_classes     = num_classes,
+                    nms_thresh      = args.nms_thresh,
                     conf_thresh     = args.conf_thresh,
                     topk            = 300,
                     deploy          = deploy,
                     no_multi_labels = args.no_multi_labels,
+                    use_nms         = True,
+                    nms_class_agnostic = args.nms_class_agnostic
                     )
             
     # -------------- Build criterion --------------

+ 24 - 3
models/detectors/rtdetr/rtdetr.py

@@ -2,9 +2,11 @@ import torch
 import torch.nn as nn
 
 try:
+    from .basic_modules.basic import multiclass_nms
     from .rtdetr_encoder import build_image_encoder
     from .rtdetr_decoder import build_transformer
 except:
+    from  basic_modules.basic import multiclass_nms  # non-relative: the relative form fails here too
     from  rtdetr_encoder import build_image_encoder
     from  rtdetr_decoder import build_transformer
 
@@ -15,9 +17,12 @@ class RT_DETR(nn.Module):
                  cfg,
                  num_classes = 80,
                  conf_thresh = 0.1,
+                 nms_thresh  = 0.5,
                  topk        = 300,
                  deploy      = False,
                  no_multi_labels = False,
+                 use_nms     = False,
+                 nms_class_agnostic = False,
                  ):
         super().__init__()
         # ----------- Basic setting -----------
@@ -28,6 +33,13 @@ class RT_DETR(nn.Module):
         self.deploy = deploy
         # scale hidden channels by width_factor
         cfg['hidden_dim'] = round(cfg['hidden_dim'] * cfg['width'])
+        ## Post-process parameters
+        self.use_nms = use_nms
+        self.nms_thresh = nms_thresh
+        self.conf_thresh = conf_thresh
+        self.topk_candidates = topk
+        self.no_multi_labels = no_multi_labels
+        self.nms_class_agnostic = nms_class_agnostic
 
         # ----------- Network setting -----------
         ## Image encoder
@@ -89,6 +101,15 @@ class RT_DETR(nn.Module):
             topk_labels = topk_idxs % self.num_classes
             topk_bboxes = box_pred[topk_box_idxs]
 
+        topk_scores = topk_scores.cpu().numpy()
+        topk_labels = topk_labels.cpu().numpy()
+        topk_bboxes = topk_bboxes.cpu().numpy()
+
+        # nms
+        if self.use_nms:
+            topk_scores, topk_labels, topk_bboxes = multiclass_nms(
+                topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
+
         return topk_bboxes, topk_scores, topk_labels
     
     def forward(self, x, targets=None):
@@ -114,9 +135,9 @@ class RT_DETR(nn.Module):
             bboxes, scores, labels = self.post_process(box_pred, cls_pred)
 
             outputs = {
-                "scores": scores.cpu().numpy(),
-                "labels": labels.cpu().numpy(),
-                "bboxes": bboxes.cpu().numpy(),
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes,
             }
 
             return outputs