| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- import numpy as np
- # ---------------------------- NMS ----------------------------
# basic NMS
def nms(bboxes, scores, nms_thresh):
    """Greedy non-maximum suppression implemented in pure NumPy.

    Boxes are visited in descending score order. Each visited box is kept,
    and every remaining box whose IoU with it exceeds ``nms_thresh`` is
    discarded.

    Args:
        bboxes: array indexed as ``bboxes[:, 0:4]`` holding one box per row as
            (xmin, ymin, xmax, ymax).
        scores: per-box confidence scores, one entry per row of ``bboxes``.
        nms_thresh: IoU threshold. A candidate is kept only if its IoU with
            every previously kept box is ``<= nms_thresh``.

    Returns:
        List of row indices into ``bboxes`` for the surviving boxes, ordered
        from highest to lowest score. Empty input yields an empty list.
    """
    x1 = bboxes[:, 0]  # xmin
    y1 = bboxes[:, 1]  # ymin
    x2 = bboxes[:, 2]  # xmax
    y2 = bboxes[:, 3]  # ymax
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # highest score first

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        # Intersection rectangle between box i and every remaining candidate.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        # Clamp at exactly 0: a positive floor would give disjoint boxes a
        # non-zero IoU and make them suppressible when nms_thresh == 0.
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h
        # Small epsilon avoids division by zero for degenerate (zero-area) boxes.
        iou = inter / (areas[i] + areas[rest] - inter + 1e-14)

        # Retain only candidates that do not overlap box i too much.
        order = rest[iou <= nms_thresh]
    return keep
# class-agnostic NMS
def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
    """Apply a single NMS pass over all detections, ignoring their labels.

    A box may suppress any other box regardless of class.

    Args:
        scores: per-detection confidence scores.
        labels: per-detection class labels; carried through but not consulted.
        bboxes: per-detection boxes as (xmin, ymin, xmax, ymax) rows.
        nms_thresh: IoU threshold forwarded to ``nms``.

    Returns:
        Tuple ``(scores, labels, bboxes)`` restricted to the survivors, in the
        order returned by ``nms``.
    """
    survivors = nms(bboxes, scores, nms_thresh)
    return scores[survivors], labels[survivors], bboxes[survivors]
# class-aware NMS
def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
    """Apply NMS separately within each class label.

    Boxes only suppress other boxes that carry the same label. Only labels
    ``0 .. num_classes - 1`` are visited, so any detection whose label lies
    outside that range is dropped from the output.

    Args:
        scores: per-detection confidence scores.
        labels: per-detection class labels, compared against ``range(num_classes)``.
        bboxes: per-detection boxes as (xmin, ymin, xmax, ymax) rows.
        nms_thresh: IoU threshold forwarded to ``nms``.
        num_classes: number of class ids to iterate over.

    Returns:
        Tuple ``(scores, labels, bboxes)`` restricted to the surviving
        detections, kept in their original input order (not re-sorted by score).
    """
    keep_mask = np.zeros(len(bboxes), dtype=bool)
    for cls_id in range(num_classes):
        inds = np.where(labels == cls_id)[0]
        if len(inds) == 0:
            continue
        c_keep = nms(bboxes[inds], scores[inds], nms_thresh)
        # c_keep indexes into the per-class subset; map back to global rows.
        keep_mask[inds[c_keep]] = True
    return scores[keep_mask], labels[keep_mask], bboxes[keep_mask]
# multi-class NMS
def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
    """Dispatch to class-aware or class-agnostic NMS.

    Args:
        scores: per-detection confidence scores.
        labels: per-detection class labels.
        bboxes: per-detection boxes as (xmin, ymin, xmax, ymax) rows.
        nms_thresh: IoU threshold used for suppression.
        num_classes: number of class ids; used only by the class-aware path.
        class_agnostic: if truthy, boxes of different labels may suppress each
            other; otherwise suppression happens within each label only.

    Returns:
        Tuple ``(scores, labels, bboxes)`` of the surviving detections.
    """
    if not class_agnostic:
        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
    return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
|