yjh0410 vor 2 Jahren
Ursprung
Commit
23b45c7783
4 geänderte Dateien mit 107 neuen und 5 gelöschten Zeilen
  1. 2 2
      eval.py
  2. 13 2
      models/detectors/yolov8/yolov8.py
  3. 1 1
      test.py
  4. 91 0
      utils/misc.py

+ 2 - 2
eval.py

@@ -32,9 +32,9 @@ def parse_args():
                         help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
-    parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
+    parser.add_argument('-ct', '--conf_thresh', default=0.001, type=float,
                         help='confidence threshold')
-    parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
+    parser.add_argument('-nt', '--nms_thresh', default=0.7, type=float,
                         help='NMS threshold')
     parser.add_argument('--topk', default=1000, type=int,
                         help='topk candidates dets of each level before NMS')

+ 13 - 2
models/detectors/yolov8/yolov8.py

@@ -10,7 +10,7 @@ from .yolov8_head import build_det_head
 from .yolov8_pred import build_pred_layer
 
 # --------------- External components ---------------
-from utils.misc import multiclass_nms
+from utils.misc import multiclass_nms, non_max_suppression
 
 
 # YOLOv8
@@ -151,8 +151,19 @@ class YOLOv8(nn.Module):
 
             return outputs
         else:
+            cls_preds = torch.sigmoid(torch.cat(all_cls_preds, dim=1))[0]
+            box_preds = torch.cat(all_box_preds, dim=1)[0]
+            predictions = torch.cat([box_preds, cls_preds], dim=-1)
+            outputs = non_max_suppression(predictions,
+                                          self.conf_thresh,
+                                          self.nms_thresh,
+                                          agnostic=self.nms_class_agnostic,
+                                          max_det=300,
+                                          classes=None)
+            bboxes, scores, labels = outputs[:, :4], outputs[:, 4], outputs[:, 5]
+            
             # post process
-            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            # bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
         
             return bboxes, scores, labels
 

+ 1 - 1
test.py

@@ -43,7 +43,7 @@ def parse_args():
                         help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
-    parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
+    parser.add_argument('-ct', '--conf_thresh', default=0.25, type=float,
                         help='confidence threshold')
     parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                         help='NMS threshold')

+ 91 - 0
utils/misc.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, DistributedSampler
+import torchvision
 
 import cv2
 import math
@@ -325,6 +326,96 @@ def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnost
     else:
         return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
 
+def non_max_suppression(
+        prediction,
+        conf_thres=0.25,
+        iou_thres=0.45,
+        classes=None,
+        agnostic=False,
+        multi_label=False,
+        max_det=300,
+        nc=0,  # number of classes (optional)
+        max_nms=30000,
+        max_wh=7680,
+):
+    """
+    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+
+    Args:
+        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
+            containing the predicted boxes, classes, and masks. The tensor should be in the format
+            output by a model, such as YOLO.
+        conf_thres (float): The confidence threshold below which boxes will be filtered out.
+            Valid values are between 0.0 and 1.0.
+        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
+            Valid values are between 0.0 and 1.0.
+        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
+        agnostic (bool): If True, the model is agnostic to the number of classes, and all
+            classes will be considered as one.
+        multi_label (bool): If True, each box may have multiple labels.
+        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
+            list contains the apriori labels for a given image. The list should be in the format
+            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
+        max_det (int): The maximum number of boxes to keep after NMS.
+        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
+        max_time_img (float): The maximum time (seconds) for processing one image.
+        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
+        max_wh (int): The maximum box width and height in pixels
+
+    Returns:
+        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
+            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
+            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
+    """
+
+    # Checks
+    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+
+    device = prediction.device  # [N, C+4]
+    nc = nc or (prediction.shape[1] - 4)  # number of classes
+    xc = prediction[:, 4:].amax(1) > conf_thres  # candidates
+
+    # Settings
+    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
+    output = torch.zeros((0, 6), device=device)
+
+    # Apply constraints
+    prediction = prediction[xc]  # confidence
+
+    # If none remain process next image
+    if not prediction.shape[0]:
+        pass
+
+    # Detections matrix nx6 (xyxy, conf, cls)
+    box, cls = prediction.split((4, nc), 1)
+
+    if multi_label:
+        i, j = torch.where(cls > conf_thres)
+        prediction = torch.cat((box[i], prediction[i, 4 + j, None], j[:, None].float()), 1)
+    else:  # best class only
+        conf, j = cls.max(1, keepdim=True)
+        prediction = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+    # Filter by class
+    if classes is not None:
+        prediction = prediction[(prediction[:, 5:6] == torch.tensor(classes, device=device)).any(1)]
+
+    # Check shape
+    n = prediction.shape[0]  # number of boxes
+    if n > max_nms:  # excess boxes
+        prediction = prediction[prediction[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+
+    # Batched NMS
+    c = prediction[:, 5:6] * (0 if agnostic else max_wh)  # classes
+    boxes, scores = prediction[:, :4] + c, prediction[:, 4]  # boxes (offset by class), scores
+    i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+    i = i[:max_det]  # limit detections
+
+    output = prediction[i]
+
+    return output
+
 
 # ---------------------------- Processor for Deployment ----------------------------
 ## Pre-processer