Răsfoiți Sursa

use nms for RT-DETR

yjh0410 1 an în urmă
părinte
comite
63a2f44fb2

+ 1 - 0
models/detectors/rtdetr/basic_modules/basic.py

@@ -72,6 +72,7 @@ def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnost
     else:
         return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
 
+
 # ----------------- MLP modules -----------------
 class MLP(nn.Module):
     def __init__(self, in_dim, hidden_dim, out_dim, num_layers):

+ 1 - 1
models/detectors/rtdetr/build.py

@@ -24,7 +24,7 @@ def build_rtdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
                     topk            = 300,
                     deploy          = deploy,
                     no_multi_labels = args.no_multi_labels,
-                    use_nms         = True,
+                    use_nms         = True,   # NMS is beneficial 
                     nms_class_agnostic = args.nms_class_agnostic
                     )
             

+ 1 - 4
models/detectors/rtdetr/rtdetr.py

@@ -11,7 +11,7 @@ except:
     from  rtdetr_decoder import build_transformer
 
 
-# Real-time Transformer-based Object Detector
+# Real-time DETR
 class RT_DETR(nn.Module):
     def __init__(self,
                  cfg,
@@ -28,8 +28,6 @@ class RT_DETR(nn.Module):
         # ----------- Basic setting -----------
         self.num_classes = num_classes
         self.num_topk = topk
-        self.conf_thresh = conf_thresh
-        self.no_multi_labels = no_multi_labels
         self.deploy = deploy
         # scale hidden channels by width_factor
         cfg['hidden_dim'] = round(cfg['hidden_dim'] * cfg['width'])
@@ -37,7 +35,6 @@ class RT_DETR(nn.Module):
         self.use_nms = use_nms
         self.nms_thresh = nms_thresh
         self.conf_thresh = conf_thresh
-        self.topk_candidates = topk
         self.no_multi_labels = no_multi_labels
         self.nms_class_agnostic = nms_class_agnostic
 

+ 71 - 0
models/detectors/rtpdetr/basic_modules/basic.py

@@ -1,5 +1,6 @@
 import math
 import warnings
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -72,6 +73,76 @@ def delta2bbox(proposals,
     return bboxes
 
 
+# ---------------------------- NMS ----------------------------
+## basic NMS
+def nms(bboxes, scores, nms_thresh):
+    """"Pure Python NMS."""
+    x1 = bboxes[:, 0]  #xmin
+    y1 = bboxes[:, 1]  #ymin
+    x2 = bboxes[:, 2]  #xmax
+    y2 = bboxes[:, 3]  #ymax
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        # compute iou
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(1e-10, xx2 - xx1)
+        h = np.maximum(1e-10, yy2 - yy1)
+        inter = w * h
+
+        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
+        #reserve all the boundingbox whose ovr less than thresh
+        inds = np.where(iou <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
+## class-agnostic NMS 
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+    # nms
+    keep = nms(bboxes, scores, nms_thresh)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## class-aware NMS 
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+    # nms
+    keep = np.zeros(len(bboxes), dtype=np.int32)
+    for i in range(num_classes):
+        inds = np.where(labels == i)[0]
+        if len(inds) == 0:
+            continue
+        c_bboxes = bboxes[inds]
+        c_scores = scores[inds]
+        c_keep = nms(c_bboxes, c_scores, nms_thresh)
+        keep[inds[c_keep]] = 1
+    keep = np.where(keep > 0)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## multi-class NMS 
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+    if class_agnostic:
+        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+    else:
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
+
+
 # ----------------- Customed NormLayer Ops -----------------
 class FrozenBatchNorm2d(torch.nn.Module):
     def __init__(self, n):

+ 2 - 0
models/detectors/rtpdetr/build.py

@@ -23,6 +23,8 @@ def build_rtpdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
                      topk            = 300,
                      deploy          = deploy,
                      no_multi_labels = args.no_multi_labels,
+                     use_nms         = True,   # NMS is beneficial 
+                     nms_class_agnostic = args.nms_class_agnostic
                      )
             
     # -------------- Build criterion --------------

+ 24 - 8
models/detectors/rtpdetr/rtpdetr.py

@@ -3,26 +3,29 @@ import torch
 import torch.nn as nn
 
 try:
-    from .basic_modules.basic import MLP
+    from .basic_modules.basic import MLP, multiclass_nms
     from .basic_modules.transformer import get_clones
     from .rtpdetr_encoder import build_image_encoder
     from .rtpdetr_decoder import build_transformer
 except:
-    from  basic_modules.basic import MLP
+    from  basic_modules.basic import MLP, multiclass_nms
     from  basic_modules.transformer import get_clones
     from  rtpdetr_encoder import build_image_encoder
     from  rtpdetr_decoder import build_transformer
 
 
-# Real-time Plain Transformer-based Object Detector
+# Real-time PlainDETR
 class RT_PDETR(nn.Module):
     def __init__(self,
                  cfg,
                  num_classes = 80,
                  conf_thresh = 0.1,
+                 nms_thresh  = 0.5,
                  topk        = 300,
                  deploy      = False,
                  no_multi_labels = False,
+                 use_nms     = False,
+                 nms_class_agnostic = False,
                  aux_loss    = False,
                  ):
         super().__init__()
@@ -33,11 +36,15 @@ class RT_PDETR(nn.Module):
         self.num_classes = num_classes
         self.num_topk = topk
         self.aux_loss = aux_loss
-        self.conf_thresh = conf_thresh
-        self.no_multi_labels = no_multi_labels
         self.deploy = deploy
         # scale hidden channels by width_factor
         cfg['hidden_dim'] = round(cfg['hidden_dim'] * cfg['width'])
+        ## Post-process parameters
+        self.use_nms = use_nms
+        self.nms_thresh = nms_thresh
+        self.conf_thresh = conf_thresh
+        self.no_multi_labels = no_multi_labels
+        self.nms_class_agnostic = nms_class_agnostic
 
         # ----------- Network setting -----------
         ## Image encoder
@@ -162,6 +169,15 @@ class RT_PDETR(nn.Module):
             topk_labels = topk_idxs % self.num_classes
             topk_bboxes = box_pred[topk_box_idxs]
 
+        topk_scores = topk_scores.cpu().numpy()
+        topk_labels = topk_labels.cpu().numpy()
+        topk_bboxes = topk_bboxes.cpu().numpy()
+
+        # nms
+        if self.use_nms:
+            topk_scores, topk_labels, topk_bboxes = multiclass_nms(
+                topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
+
         return topk_bboxes, topk_scores, topk_labels
     
     @torch.jit.unused
@@ -225,9 +241,9 @@ class RT_PDETR(nn.Module):
         bboxes, scores, labels = self.post_process(box_pred, cls_pred)
 
         outputs = {
-            "scores": scores.cpu().numpy(),
-            "labels": labels.cpu().numpy(),
-            "bboxes": bboxes.cpu().numpy(),
+            "scores": scores,
+            "labels": labels,
+            "bboxes": bboxes,
         }
 
         return outputs