Ver Fonte

debug YOLOvx

yjh0410 há 2 anos atrás
pai
commit
1c3fce90a6
1 ficheiros alterados com 25 adições e 36 exclusões
  1. 25 36
      models/detectors/yolovx/yolovx.py

+ 25 - 36
models/detectors/yolovx/yolovx.py

@@ -86,6 +86,24 @@ class YOLOvx(nn.Module):
 
         return anchors
         
+    ## decode bbox
+    def decode_bbox(self, reg_pred, anchors, stride):
+        B, M = reg_pred.shape[:2]
+        # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+        reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
+        # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+        reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
+        # [B, reg_max, 4, M] -> [B, 1, 4, M]
+        reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
+        # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+        reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+        ## tlbr -> xyxy
+        x1y1_pred = anchors[None] - reg_pred[..., :2] * stride
+        x2y2_pred = anchors[None] + reg_pred[..., 2:] * stride
+        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+        return box_pred
+    
     ## post-process
     def post_process(self, cls_preds, box_preds):
         """
@@ -169,21 +187,7 @@ class YOLOvx(nn.Module):
             # process preds
             cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
             reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-
-            # ----------------------- Decode bbox -----------------------
-            B, M = reg_pred.shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-            reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
-            ## tlbr -> xyxy
-            x1y1_pred = anchors[None] - reg_pred[..., :2] * self.stride[level]
-            x2y2_pred = anchors[None] + reg_pred[..., 2:] * self.stride[level]
-            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+            box_pred = self.decode_bbox(reg_pred, anchors, self.stride[level])
 
             # collect preds
             all_cls_preds.append(cls_pred[0])
@@ -229,34 +233,19 @@ class YOLOvx(nn.Module):
             all_reg_preds = []
             all_box_preds = []
             for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
+                # anchors & stride tensor
+                B, _, H, W = cls_feat.size()
+                anchors = self.generate_anchors(level, [H, W])                         # [M, 4]
+                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level] # [M, 1]
+                
                 # prediction
                 cls_pred = self.cls_preds[level](cls_feat)
                 reg_pred = self.reg_preds[level](reg_feat)
 
-                B, _, H, W = cls_pred.size()
-                # generate anchor boxes: [M, 4]
-                anchors = self.generate_anchors(level, [H, W])
-                # stride tensor: [M, 1]
-                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
-                
                 # process preds
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-
-                # ----------------------- Decode bbox -----------------------
-                B, M = reg_pred.shape[:2]
-                # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-                reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
-                # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-                reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
-                # [B, reg_max, 4, M] -> [B, 1, 4, M]
-                reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
-                # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-                reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
-                ## tlbr -> xyxy
-                x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.stride[level]
-                x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.stride[level]
-                box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+                box_pred = self.decode_bbox(reg_pred, anchors, self.stride[level])
 
                 # collect preds
                 all_cls_preds.append(cls_pred)