yjh0410 1 年之前
父节点
当前提交
aaff3e9652
共有 1 个文件被更改,包括 3 次插入20 次删除
  1. 3 20
      yolo/utils/vis_tools.py

+ 3 - 20
yolo/utils/vis_tools.py

@@ -34,7 +34,7 @@ def visualize(image, bboxes, scores, labels, class_colors, class_names):
     return image
         
 ## Visualize the input data during the training stage
-def vis_data(images, targets, num_classes=80, normalized_bbox=False, color_format='bgr', pixel_mean=None, pixel_std=None, box_format="xyxy"):
+def vis_data(images, targets, num_classes=80, pixel_mean=None, pixel_std=None):
     """
         images: (tensor) [B, 3, H, W]
         targets: (list) a list of targets
@@ -54,31 +54,14 @@ def vis_data(images, targets, num_classes=80, normalized_bbox=False, color_forma
         # denormalize image
         if pixel_mean is not None and pixel_std is not None:
             image = image * pixel_std + pixel_mean
-        
-        if color_format == 'rgb':
-            image = image[..., (2, 1, 0)] # RGB to BGR
-            
+                    
         image = image.astype(np.uint8)
         image = image.copy()
         img_h, img_w = image.shape[:2]
 
         # visualize target
         for box, label in zip(tgt_boxes, tgt_labels):
-            if box_format == "xyxy":
-                x1, y1, x2, y2 = box
-            elif box_format == "xywh":
-                cx, cy, bw, bh = box
-                x1 = cx - 0.5 * bw
-                y1 = cy - 0.5 * bh
-                x2 = cx + 0.5 * bw
-                y2 = cy + 0.5 * bh
-
-            if normalized_bbox:
-                x1 *= img_w
-                y1 *= img_h
-                x2 *= img_w
-                y2 *= img_h
-
+            x1, y1, x2, y2 = box
             x1, y1 = int(x1), int(y1)
             x2, y2 = int(x2), int(y2)
             cls_id = int(label)