|
|
@@ -2,6 +2,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
|
+import torchvision
|
|
|
|
|
|
import cv2
|
|
|
import math
|
|
|
@@ -325,6 +326,96 @@ def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnost
|
|
|
else:
|
|
|
return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
|
|
|
|
|
|
+def non_max_suppression(
|
|
|
+ prediction,
|
|
|
+ conf_thres=0.25,
|
|
|
+ iou_thres=0.45,
|
|
|
+ classes=None,
|
|
|
+ agnostic=False,
|
|
|
+ multi_label=False,
|
|
|
+ max_det=300,
|
|
|
+ nc=0, # number of classes (optional)
|
|
|
+ max_nms=30000,
|
|
|
+ max_wh=7680,
|
|
|
+):
|
|
|
+ """
|
|
|
+ Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
|
|
|
+ containing the predicted boxes, classes, and masks. The tensor should be in the format
|
|
|
+ output by a model, such as YOLO.
|
|
|
+ conf_thres (float): The confidence threshold below which boxes will be filtered out.
|
|
|
+ Valid values are between 0.0 and 1.0.
|
|
|
+ iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
|
|
|
+ Valid values are between 0.0 and 1.0.
|
|
|
+ classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
|
|
|
+ agnostic (bool): If True, the model is agnostic to the number of classes, and all
|
|
|
+ classes will be considered as one.
|
|
|
+ multi_label (bool): If True, each box may have multiple labels.
|
|
|
+ labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
|
|
|
+ list contains the apriori labels for a given image. The list should be in the format
|
|
|
+ output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
|
|
|
+ max_det (int): The maximum number of boxes to keep after NMS.
|
|
|
+ nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
|
|
|
+ max_time_img (float): The maximum time (seconds) for processing one image.
|
|
|
+ max_nms (int): The maximum number of boxes into torchvision.ops.nms().
|
|
|
+ max_wh (int): The maximum box width and height in pixels
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
|
|
|
+ shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
|
|
+ (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
|
|
+ """
|
|
|
+
|
|
|
+ # Checks
|
|
|
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
|
|
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
|
|
+
|
|
|
+ device = prediction.device # [N, C+4]
|
|
|
+ nc = nc or (prediction.shape[1] - 4) # number of classes
|
|
|
+ xc = prediction[:, 4:].amax(1) > conf_thres # candidates
|
|
|
+
|
|
|
+ # Settings
|
|
|
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
|
|
+ output = torch.zeros((0, 6), device=device)
|
|
|
+
|
|
|
+ # Apply constraints
|
|
|
+ prediction = prediction[xc] # confidence
|
|
|
+
|
|
|
+ # If none remain process next image
|
|
|
+ if not prediction.shape[0]:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Detections matrix nx6 (xyxy, conf, cls)
|
|
|
+ box, cls = prediction.split((4, nc), 1)
|
|
|
+
|
|
|
+ if multi_label:
|
|
|
+ i, j = torch.where(cls > conf_thres)
|
|
|
+ prediction = torch.cat((box[i], prediction[i, 4 + j, None], j[:, None].float()), 1)
|
|
|
+ else: # best class only
|
|
|
+ conf, j = cls.max(1, keepdim=True)
|
|
|
+ prediction = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
|
|
|
+
|
|
|
+ # Filter by class
|
|
|
+ if classes is not None:
|
|
|
+ prediction = prediction[(prediction[:, 5:6] == torch.tensor(classes, device=device)).any(1)]
|
|
|
+
|
|
|
+ # Check shape
|
|
|
+ n = prediction.shape[0] # number of boxes
|
|
|
+ if n > max_nms: # excess boxes
|
|
|
+ prediction = prediction[prediction[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
|
|
+
|
|
|
+ # Batched NMS
|
|
|
+ c = prediction[:, 5:6] * (0 if agnostic else max_wh) # classes
|
|
|
+ boxes, scores = prediction[:, :4] + c, prediction[:, 4] # boxes (offset by class), scores
|
|
|
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
|
|
+ i = i[:max_det] # limit detections
|
|
|
+
|
|
|
+ output = prediction[i]
|
|
|
+
|
|
|
+ return output
|
|
|
+
|
|
|
|
|
|
# ---------------------------- Processor for Deployment ----------------------------
|
|
|
## Pre-processer
|