yjh0410 преди 2 години
родител
ревизия
ae87755832
променени са 7 файла, в които са добавени 45 реда и са изтрити 27 реда
  1. 4 2
      demo.py
  2. 3 1
      eval.py
  3. 11 11
      models/detectors/yolov8/build.py
  4. 5 3
      models/detectors/yolov8/yolov8.py
  5. 4 2
      test.py
  6. 3 1
      train.py
  7. 15 7
      utils/misc.py

+ 4 - 2
demo.py

@@ -58,8 +58,10 @@ def parse_args():
                         help='confidence threshold')
     parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                         help='NMS threshold')
-    parser.add_argument('--topk', default=100, type=int,
-                        help='topk candidates for testing')
+    parser.add_argument('--topk', default=1000, type=int,
+                        help='topk candidates dets of each level before NMS')
+    parser.add_argument('--max_dets', default=300, type=int,
+                        help='max number of dets after NMS')
     parser.add_argument("--deploy", action="store_true", default=False,
                         help="deploy mode or not")
     parser.add_argument('--fuse_repconv', action='store_true', default=False,

+ 3 - 1
eval.py

@@ -37,7 +37,9 @@ def parse_args():
     parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
                         help='NMS threshold')
     parser.add_argument('--topk', default=1000, type=int,
-                        help='topk candidates for testing')
+                        help='topk candidates dets of each level before NMS')
+    parser.add_argument('--max_dets', default=300, type=int,
+                        help='max number of dets after NMS')
     parser.add_argument("--no_decode", action="store_true", default=False,
                         help="not decode in inference or yes")
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,

+ 11 - 11
models/detectors/yolov8/build.py

@@ -17,17 +17,17 @@ def build_yolov8(args, cfg, device, num_classes=80, trainable=False, deploy=Fals
     print('Model Configuration: \n', cfg)
     
     # -------------- Build YOLO --------------
-    model = YOLOv8(
-        cfg=cfg,
-        device=device, 
-        num_classes=num_classes,
-        trainable=trainable,
-        conf_thresh=args.conf_thresh,
-        nms_thresh=args.nms_thresh,
-        topk=args.topk,
-        deploy=deploy,
-        nms_class_agnostic=args.nms_class_agnostic
-        )
+    model = YOLOv8(cfg                = cfg,
+                   device             = device, 
+                   num_classes        = num_classes,
+                   trainable          = trainable,
+                   conf_thresh        = args.conf_thresh,
+                   nms_thresh         = args.nms_thresh,
+                   topk               = args.topk,
+                   max_dets           = args.max_dets,
+                   deploy             = deploy,
+                   nms_class_agnostic = args.nms_class_agnostic
+                   )
 
     # -------------- Initialize YOLO --------------
     for m in model.modules():

+ 5 - 3
models/detectors/yolov8/yolov8.py

@@ -21,7 +21,8 @@ class YOLOv8(nn.Module):
                  num_classes = 20,
                  conf_thresh = 0.01,
                  nms_thresh  = 0.5,
-                 topk        = 100,
+                 topk        = 1000,
+                 max_dets    = 300,
                  trainable   = False,
                  deploy      = False,
                  nms_class_agnostic = False):
@@ -38,6 +39,7 @@ class YOLOv8(nn.Module):
         self.num_levels = len(self.strides)
         self.num_classes = num_classes
         self.topk = topk
+        self.max_dets = max_dets
         self.deploy = deploy
         self.nms_class_agnostic = nms_class_agnostic
         
@@ -116,8 +118,8 @@ class YOLOv8(nn.Module):
 
         # nms
         scores, labels, bboxes = multiclass_nms(
-            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
-
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic, self.max_dets)
+        
         return bboxes, scores, labels
 
     # ---------------------- Main Process for Inference ----------------------

+ 4 - 2
test.py

@@ -47,8 +47,10 @@ def parse_args():
                         help='confidence threshold')
     parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                         help='NMS threshold')
-    parser.add_argument('--topk', default=100, type=int,
-                        help='topk candidates for testing')
+    parser.add_argument('--topk', default=1000, type=int,
+                        help='topk candidates dets of each level before NMS')
+    parser.add_argument('--max_dets', default=300, type=int,
+                        help='max number of dets after NMS')
     parser.add_argument("--no_decode", action="store_true", default=False,
                         help="not decode in inference or yes")
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,

+ 3 - 1
train.py

@@ -66,7 +66,9 @@ def parse_args():
     parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
                         help='NMS threshold')
     parser.add_argument('--topk', default=1000, type=int,
-                        help='topk candidates for evaluation')
+                        help='topk candidates dets of each level before NMS')
+    parser.add_argument('--max_dets', default=300, type=int,
+                        help='max number of dets after NMS')
     parser.add_argument('-p', '--pretrained', default=None, type=str,
                         help='load pretrained weight')
     parser.add_argument('-r', '--resume', default=None, type=str,

+ 15 - 7
utils/misc.py

@@ -290,18 +290,22 @@ def nms(bboxes, scores, nms_thresh):
     return keep
 
 ## class-agnostic NMS 
-def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh, max_dets=300):
     # nms
     keep = nms(bboxes, scores, nms_thresh)
-
     scores = scores[keep]
     labels = labels[keep]
     bboxes = bboxes[keep]
 
+    # max dets
+    scores = scores[:max_dets]
+    labels = labels[:max_dets]
+    bboxes = bboxes[:max_dets]
+
     return scores, labels, bboxes
 
 ## class-aware NMS 
-def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes, max_dets=300):
     # nms
     keep = np.zeros(len(bboxes), dtype=np.int32)
     for i in range(num_classes):
@@ -312,20 +316,24 @@ def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
         c_scores = scores[inds]
         c_keep = nms(c_bboxes, c_scores, nms_thresh)
         keep[inds[c_keep]] = 1
-
     keep = np.where(keep > 0)
     scores = scores[keep]
     labels = labels[keep]
     bboxes = bboxes[keep]
 
+    # max dets
+    scores = scores[:max_dets]
+    labels = labels[:max_dets]
+    bboxes = bboxes[:max_dets]
+
     return scores, labels, bboxes
 
 ## multi-class NMS 
-def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False, max_dets=300):
     if class_agnostic:
-        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh, max_dets)
     else:
-        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes, max_dets)
 
 
 # ---------------------------- Processor for Deployment ----------------------------