vis_tools.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import cv2
  2. import os
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from dataset.coco import coco_class_index, coco_class_labels
  6. # draw bbox & label on the image
  7. def plot_bbox_labels(img, bbox, label, cls_color, test_scale=0.4):
  8. x1, y1, x2, y2 = bbox
  9. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  10. t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0]
  11. # plot bbox
  12. cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2)
  13. # plot title bbox
  14. cv2.rectangle(img, (x1, y1-t_size[1]), (int(x1 + t_size[0] * test_scale), y1), cls_color, -1)
  15. # put the test on the title bbox
  16. cv2.putText(img, label, (int(x1), int(y1 - 5)), 0, test_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA)
  17. return img
  18. # visualize the detection results
  19. def visualize(img, bboxes, scores, labels, class_colors, vis_thresh=0.3):
  20. ts = 0.4
  21. for i, bbox in enumerate(bboxes):
  22. if scores[i] > vis_thresh:
  23. cls_color = class_colors[int(labels[i])]
  24. cls_id = coco_class_index[int(labels[i])]
  25. mess = '%s: %.2f' % (coco_class_labels[cls_id], scores[i])
  26. img = plot_bbox_labels(img, bbox, mess, cls_color, test_scale=ts)
  27. return img
  28. # visualize the input data during the training stage
  29. def vis_data(images, targets):
  30. """
  31. images: (tensor) [B, 3, H, W]
  32. targets: (list) a list of targets
  33. """
  34. batch_size = images.size(0)
  35. np.random.seed(0)
  36. class_colors = [(np.random.randint(255),
  37. np.random.randint(255),
  38. np.random.randint(255)) for _ in range(20)]
  39. for bi in range(batch_size):
  40. # to numpy
  41. image = images[bi].permute(1, 2, 0).cpu().numpy()
  42. target = targets[bi]
  43. image = image.astype(np.uint8)
  44. image = image.copy()
  45. tgt_boxes = target['boxes']
  46. tgt_labels = target['labels']
  47. for box, label in zip(tgt_boxes, tgt_labels):
  48. x1, y1, x2, y2 = box
  49. cls_id = int(label)
  50. x1, y1 = int(x1), int(y1)
  51. x2, y2 = int(x2), int(y2)
  52. color = class_colors[cls_id]
  53. # draw box
  54. cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
  55. cv2.imshow('train target', image)
  56. cv2.waitKey(0)
  57. # convert a feature map to a heatmap
  58. def convert_feature_heatmap(feature):
  59. """
  60. feature: (ndarray) [H, W, C]
  61. """
  62. heatmap = None
  63. return heatmap
  64. # draw feature on the image
  65. def draw_feature(img, features, save=None):
  66. """
  67. img: (ndarray & cv2.Mat) [H, W, C], where the C is 3 for RGB or 1 for Gray.
  68. features: (List[ndarray]). It is a list of the multiple feature map whose shape is [H, W, C].
  69. save: (bool) save the result or not.
  70. """
  71. img_h, img_w = img.shape[:2]
  72. for i, fmp in enumerate(features):
  73. hmp = convert_feature_heatmap(fmp)
  74. hmp = cv2.resize(hmp, (img_w, img_h))
  75. hmp = hmp.astype(np.uint8)*255
  76. hmp_rgb = cv2.applyColorMap(hmp, cv2.COLORMAP_JET)
  77. superimposed_img = hmp_rgb * 0.4 + img
  78. # show the heatmap
  79. plt.imshow(hmp)
  80. plt.close()
  81. # show the image with heatmap
  82. cv2.imshow("image with heatmap", superimposed_img)
  83. cv2.waitKey(0)
  84. cv2.destroyAllWindows()
  85. if save:
  86. save_dir = 'feature_heatmap'
  87. os.makedirs(save_dir, exist_ok=True)
  88. cv2.imwrite(os.path.join(save_dir, 'feature_{}.png'.format(i) ), superimposed_img)