nms_ops.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import numpy as np
  2. # ---------------------------- NMS ----------------------------
  3. ## basic NMS
  4. def nms(bboxes, scores, nms_thresh):
  5. """"Pure Python NMS."""
  6. x1 = bboxes[:, 0] #xmin
  7. y1 = bboxes[:, 1] #ymin
  8. x2 = bboxes[:, 2] #xmax
  9. y2 = bboxes[:, 3] #ymax
  10. areas = (x2 - x1) * (y2 - y1)
  11. order = scores.argsort()[::-1]
  12. keep = []
  13. while order.size > 0:
  14. i = order[0]
  15. keep.append(i)
  16. # compute iou
  17. xx1 = np.maximum(x1[i], x1[order[1:]])
  18. yy1 = np.maximum(y1[i], y1[order[1:]])
  19. xx2 = np.minimum(x2[i], x2[order[1:]])
  20. yy2 = np.minimum(y2[i], y2[order[1:]])
  21. w = np.maximum(1e-10, xx2 - xx1)
  22. h = np.maximum(1e-10, yy2 - yy1)
  23. inter = w * h
  24. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  25. #reserve all the boundingbox whose ovr less than thresh
  26. inds = np.where(iou <= nms_thresh)[0]
  27. order = order[inds + 1]
  28. return keep
  29. ## class-agnostic NMS
  30. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  31. # nms
  32. keep = nms(bboxes, scores, nms_thresh)
  33. scores = scores[keep]
  34. labels = labels[keep]
  35. bboxes = bboxes[keep]
  36. return scores, labels, bboxes
  37. ## class-aware NMS
  38. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  39. # nms
  40. keep = np.zeros(len(bboxes), dtype=np.int32)
  41. for i in range(num_classes):
  42. inds = np.where(labels == i)[0]
  43. if len(inds) == 0:
  44. continue
  45. c_bboxes = bboxes[inds]
  46. c_scores = scores[inds]
  47. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  48. keep[inds[c_keep]] = 1
  49. keep = np.where(keep > 0)
  50. scores = scores[keep]
  51. labels = labels[keep]
  52. bboxes = bboxes[keep]
  53. return scores, labels, bboxes
  54. ## multi-class NMS
  55. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  56. if class_agnostic:
  57. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  58. else:
  59. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)