vis_tools.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import cv2
  2. import os
  3. import torch
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. # -------------------------- For Detection Task --------------------------
  7. ## visualize the input data during the training stage
  8. def vis_data(images, targets, masks=None, class_labels=None, normalized_coord=False, box_format='xyxy'):
  9. """
  10. images: (tensor) [B, 3, H, W]
  11. masks: (Tensor) [B, H, W]
  12. targets: (list) a list of targets
  13. """
  14. batch_size = images.size(0)
  15. np.random.seed(0)
  16. class_colors = [(np.random.randint(255),
  17. np.random.randint(255),
  18. np.random.randint(255)) for _ in range(80)]
  19. pixel_means = [0.485, 0.456, 0.406]
  20. pixel_std = [0.229, 0.224, 0.225]
  21. for bi in range(batch_size):
  22. target = targets[bi]
  23. # to numpy
  24. image = images[bi].permute(1, 2, 0).cpu().numpy()
  25. not_mask = ~masks[bi]
  26. img_h = not_mask.cumsum(0, dtype=torch.int32)[-1, 0]
  27. img_w = not_mask.cumsum(1, dtype=torch.int32)[0, -1]
  28. # denormalize
  29. image = (image * pixel_std + pixel_means) * 255
  30. image = image[:, :, (2, 1, 0)].astype(np.uint8)
  31. image = image.copy()
  32. tgt_boxes = target['boxes'].float()
  33. tgt_labels = target['labels'].long()
  34. for box, label in zip(tgt_boxes, tgt_labels):
  35. box_ = box.clone()
  36. if normalized_coord:
  37. box_[..., [0, 2]] *= img_w
  38. box_[..., [1, 3]] *= img_h
  39. if box_format == 'xywh':
  40. box_x1y1 = box_[..., :2] - box_[..., 2:] * 0.5
  41. box_x2y2 = box_[..., :2] + box_[..., 2:] * 0.5
  42. box_ = torch.cat([box_x1y1, box_x2y2], dim=-1)
  43. x1, y1, x2, y2 = box_.long().cpu().numpy()
  44. cls_id = label.item()
  45. color = class_colors[cls_id]
  46. # draw box
  47. cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
  48. if class_labels is not None:
  49. class_name = class_labels[cls_id]
  50. # plot title bbox
  51. t_size = cv2.getTextSize(class_name, 0, fontScale=1, thickness=2)[0]
  52. cv2.rectangle(image, (x1, y1-t_size[1]), (int(x1 + t_size[0] * 0.4), y1), color, -1)
  53. # put the test on the title bbox
  54. cv2.putText(image, class_name, (x1, y1 - 5), 0, 0.4, (0, 0, 0), 1, lineType=cv2.LINE_AA)
  55. cv2.imshow('train target', image)
  56. cv2.waitKey(0)
  57. ## plot bbox & label on image
  58. def plot_bbox_labels(img, bbox, label=None, cls_color=None, text_scale=0.4):
  59. x1, y1, x2, y2 = bbox
  60. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  61. t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0]
  62. # plot bbox
  63. cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2)
  64. if label is not None:
  65. # plot title bbox
  66. cv2.rectangle(img, (x1, y1-t_size[1]), (int(x1 + t_size[0] * text_scale), y1), cls_color, -1)
  67. # put the test on the title bbox
  68. cv2.putText(img, label, (int(x1), int(y1 - 5)), 0, text_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA)
  69. return img
  70. ## visualize detection
  71. def visualize(img,
  72. bboxes,
  73. scores,
  74. labels,
  75. vis_thresh,
  76. class_colors,
  77. class_names):
  78. ts = 0.4
  79. for i, bbox in enumerate(bboxes):
  80. if scores[i] > vis_thresh:
  81. cls_id = int(labels[i])
  82. cls_color = class_colors[cls_id]
  83. mess = '%s: %.2f' % (class_names[cls_id], scores[i])
  84. img = plot_bbox_labels(img, bbox, mess, cls_color, text_scale=ts)
  85. return img
  86. ## convert a feature map to a heatmap
  87. def convert_feature_heatmap(feature):
  88. """
  89. feature: (ndarray) [H, W, C]
  90. """
  91. heatmap = None
  92. return heatmap
  93. ## draw feature on the image
  94. def draw_feature(img, features, save=None):
  95. """
  96. img: (ndarray & cv2.Mat) [H, W, C], where the C is 3 for RGB or 1 for Gray.
  97. features: (List[ndarray]). It is a list of the multiple feature map whose shape is [H, W, C].
  98. save: (bool) save the result or not.
  99. """
  100. img_h, img_w = img.shape[:2]
  101. for i, fmp in enumerate(features):
  102. hmp = convert_feature_heatmap(fmp)
  103. hmp = cv2.resize(hmp, (img_w, img_h))
  104. hmp = hmp.astype(np.uint8)*255
  105. hmp_rgb = cv2.applyColorMap(hmp, cv2.COLORMAP_JET)
  106. superimposed_img = hmp_rgb * 0.4 + img
  107. # show the heatmap
  108. plt.imshow(hmp)
  109. plt.close()
  110. # show the image with heatmap
  111. cv2.imshow("image with heatmap", superimposed_img)
  112. cv2.waitKey(0)
  113. cv2.destroyAllWindows()
  114. if save:
  115. save_dir = 'feature_heatmap'
  116. os.makedirs(save_dir, exist_ok=True)
  117. cv2.imwrite(os.path.join(save_dir, 'feature_{}.png'.format(i) ), superimposed_img)
  118. # -------------------------- For Tracking Task --------------------------
  119. def get_color(idx):
  120. idx = idx * 3
  121. color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
  122. return color
  123. def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
  124. im = np.ascontiguousarray(np.copy(image))
  125. im_h, im_w = im.shape[:2]
  126. top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
  127. #text_scale = max(1, image.shape[1] / 1600.)
  128. #text_thickness = 2
  129. #line_thickness = max(1, int(image.shape[1] / 500.))
  130. text_scale = 2
  131. text_thickness = 2
  132. line_thickness = 3
  133. radius = max(5, int(im_w/140.))
  134. cv2.putText(im, 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
  135. (0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, 2, (0, 0, 255), thickness=2)
  136. for i, tlwh in enumerate(tlwhs):
  137. x1, y1, w, h = tlwh
  138. intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
  139. obj_id = int(obj_ids[i])
  140. id_text = '{}'.format(int(obj_id))
  141. if ids2 is not None:
  142. id_text = id_text + ', {}'.format(int(ids2[i]))
  143. color = get_color(abs(obj_id))
  144. cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
  145. cv2.putText(im, id_text, (intbox[0], intbox[1]), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),
  146. thickness=text_thickness)
  147. return im