# nms.py
import numpy as np
  2. def nms(bboxes, scores, nms_thresh):
  3. """"Pure Python NMS."""
  4. x1 = bboxes[:, 0] #xmin
  5. y1 = bboxes[:, 1] #ymin
  6. x2 = bboxes[:, 2] #xmax
  7. y2 = bboxes[:, 3] #ymax
  8. areas = (x2 - x1) * (y2 - y1)
  9. order = scores.argsort()[::-1]
  10. keep = []
  11. while order.size > 0:
  12. i = order[0]
  13. keep.append(i)
  14. # compute iou
  15. xx1 = np.maximum(x1[i], x1[order[1:]])
  16. yy1 = np.maximum(y1[i], y1[order[1:]])
  17. xx2 = np.minimum(x2[i], x2[order[1:]])
  18. yy2 = np.minimum(y2[i], y2[order[1:]])
  19. w = np.maximum(1e-10, xx2 - xx1)
  20. h = np.maximum(1e-10, yy2 - yy1)
  21. inter = w * h
  22. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  23. #reserve all the boundingbox whose ovr less than thresh
  24. inds = np.where(iou <= nms_thresh)[0]
  25. order = order[inds + 1]
  26. return keep
  27. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  28. # nms
  29. keep = nms(bboxes, scores, nms_thresh)
  30. scores = scores[keep]
  31. labels = labels[keep]
  32. bboxes = bboxes[keep]
  33. return scores, labels, bboxes
  34. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  35. # nms
  36. keep = np.zeros(len(bboxes), dtype=np.int)
  37. for i in range(num_classes):
  38. inds = np.where(labels == i)[0]
  39. if len(inds) == 0:
  40. continue
  41. c_bboxes = bboxes[inds]
  42. c_scores = scores[inds]
  43. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  44. keep[inds[c_keep]] = 1
  45. keep = np.where(keep > 0)
  46. scores = scores[keep]
  47. labels = labels[keep]
  48. bboxes = bboxes[keep]
  49. return scores, labels, bboxes
  50. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  51. if class_agnostic:
  52. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  53. else:
  54. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)