Răsfoiți Sursa

debug yolov9

yjh0410 9 luni în urmă
părinte
comite
55e928d1c6
2 a modificat fișierele cu 14 adăugiri și 29 ștergeri
  1. 7 13
      yolo/models/yolov8/yolov8.py
  2. 7 16
      yolo/models/yolov9/gelan.py

+ 7 - 13
yolo/models/yolov8/yolov8.py

@@ -141,18 +141,12 @@ class Yolov8(nn.Module):
             all_cls_preds = outputs['pred_cls']
             all_box_preds = outputs['pred_box']
 
-            if deploy:
-                cls_preds = torch.cat(all_cls_preds, dim=1)[0]
-                box_preds = torch.cat(all_box_preds, dim=1)[0]
-                outputs = torch.cat([box_preds, cls_preds.sigmoid()], dim=-1)
-                
-            else:
-                # post process
-                bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-                outputs = {
-                    "scores": scores,
-                    "labels": labels,
-                    "bboxes": bboxes
-                }
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
         
         return outputs 

+ 7 - 16
yolo/models/yolov9/gelan.py

@@ -144,22 +144,13 @@ class GElan(nn.Module):
             all_cls_preds = outputs['pred_cls']
             all_box_preds = outputs['pred_box']
 
-            if self.deploy:
-                cls_preds = torch.cat(all_cls_preds, dim=1)[0]
-                box_preds = torch.cat(all_box_preds, dim=1)[0]
-                scores = cls_preds.sigmoid()
-                bboxes = box_preds
-                # [n_anchors_all, 4 + C]
-                outputs = torch.cat([bboxes, scores], dim=-1)
-
-            else:
-                # post process
-                bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-                outputs = {
-                    "scores": scores,
-                    "labels": labels,
-                    "bboxes": bboxes
-                }
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
         
         return outputs