yjh0410 hace 9 meses
padre
commit
55e928d1c6
Se han modificado 2 ficheros con 14 adiciones y 29 borrados
  1. 7 13
      yolo/models/yolov8/yolov8.py
  2. 7 16
      yolo/models/yolov9/gelan.py

+ 7 - 13
yolo/models/yolov8/yolov8.py

@@ -141,18 +141,12 @@ class Yolov8(nn.Module):
             all_cls_preds = outputs['pred_cls']
             all_box_preds = outputs['pred_box']
 
-            if deploy:
-                cls_preds = torch.cat(all_cls_preds, dim=1)[0]
-                box_preds = torch.cat(all_box_preds, dim=1)[0]
-                outputs = torch.cat([box_preds, cls_preds.sigmoid()], dim=-1)
-                
-            else:
-                # post process
-                bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-                outputs = {
-                    "scores": scores,
-                    "labels": labels,
-                    "bboxes": bboxes
-                }
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
         
         return outputs 

+ 7 - 16
yolo/models/yolov9/gelan.py

@@ -144,22 +144,13 @@ class GElan(nn.Module):
             all_cls_preds = outputs['pred_cls']
             all_box_preds = outputs['pred_box']
 
-            if self.deploy:
-                cls_preds = torch.cat(all_cls_preds, dim=1)[0]
-                box_preds = torch.cat(all_box_preds, dim=1)[0]
-                scores = cls_preds.sigmoid()
-                bboxes = box_preds
-                # [n_anchors_all, 4 + C]
-                outputs = torch.cat([bboxes, scores], dim=-1)
-
-            else:
-                # post process
-                bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-                outputs = {
-                    "scores": scores,
-                    "labels": labels,
-                    "bboxes": bboxes
-                }
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
         
         return outputs