浏览代码

debug post-process

yjh0410 2 年之前
父节点
当前提交
4ed895d30d
共有 1 个文件被更改,包括 6 次插入13 次删除
  1. 6 13
      models/detectors/yolov7/yolov7.py

+ 6 - 13
models/detectors/yolov7/yolov7.py

@@ -93,14 +93,7 @@ class YOLOv7(nn.Module):
         all_bboxes = []
         
         for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
-            conf_i = obj_pred_i[..., 0].sigmoid()
-            conf_keep_i = conf_i > self.conf_thresh
-
-            obj_pred_i = obj_pred_i[conf_keep_i]
-            cls_pred_i = cls_pred_i[conf_keep_i]
-            box_pred_i = box_pred_i[conf_keep_i]
-
-            # (H x W x C,)
+            # (H x W x KA x C,)
             scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
 
             # Keep top k top scoring indices only.
@@ -108,13 +101,13 @@ class YOLOv7(nn.Module):
 
             # torch.sort is actually faster than .topk (at least on GPUs)
             predicted_prob, topk_idxs = scores_i.sort(descending=True)
-            scores = predicted_prob[:num_topk]
+            topk_scores = predicted_prob[:num_topk]
             topk_idxs = topk_idxs[:num_topk]
 
-            # # filter out the proposals with low confidence score
-            # keep_idxs = scores > self.conf_thresh
-            # scores = scores[keep_idxs]
-            # topk_idxs = topk_idxs[keep_idxs]
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
 
             anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
             labels = topk_idxs % self.num_classes