optimize pred layer

yjh0410, 2 years ago
commit 25874e27f3
2 changed files with 130 additions and 97 deletions
  1. models/detectors/yolov8/yolov8_pred.py (+63, -48)
  2. models/detectors/yolox2/yolox2_pred.py (+67, -49)

models/detectors/yolov8/yolov8_pred.py (+63, -48)

@@ -7,16 +7,18 @@ import torch.nn.functional as F
 # Single-level pred layer
 class SingleLevelPredLayer(nn.Module):
     def __init__(self,
-                 cls_dim      :int = 256,
-                 reg_dim      :int = 256,
-                 stride       :int = 32,
-                 num_classes  :int = 80,
-                 num_coords   :int = 4):
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 reg_max     :int = 16,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
         super().__init__()
         # --------- Basic Parameters ----------
         self.stride = stride
         self.cls_dim = cls_dim
         self.reg_dim = reg_dim
+        self.reg_max = reg_max
         self.num_classes = num_classes
         self.num_coords = num_coords
 
@@ -36,19 +38,57 @@ class SingleLevelPredLayer(nn.Module):
         b.data.fill_(1.0)
         self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
-    def forward(self, cls_feat, reg_feat):
+    def generate_anchors(self, fmp_size):
         """
-            in_feats: (Tensor) [B, C, H, W]
+            fmp_size: (List) [H, W]
         """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
         cls_pred = self.cls_pred(cls_feat)
         reg_pred = self.reg_pred(reg_feat)
 
-        return cls_pred, reg_pred
-    
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # List(Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # List(Tensor) [B, M, 4*(reg_max)]
+                   "anchors": anchors,              # List(Tensor) [M, 2]
+                   "strides": self.stride,          # List(Int) = [8, 16, 32]
+                   "stride_tensor": stride_tensor   # List(Tensor) [M, 1]
+                   }
+
+        return outputs
 
 # Multi-level pred layer
 class MultiLevelPredLayer(nn.Module):
-    def __init__(self, cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3, reg_max=16):
+    def __init__(self,
+                 cls_dim,
+                 reg_dim,
+                 strides,
+                 num_classes :int = 80,
+                 num_coords  :int = 4,
+                 num_levels  :int = 3,
+                 reg_max     :int = 16):
         super().__init__()
         # --------- Basic Parameters ----------
         self.cls_dim = cls_dim
@@ -65,6 +105,7 @@ class MultiLevelPredLayer(nn.Module):
             [SingleLevelPredLayer(cls_dim     = cls_dim,
                                   reg_dim     = reg_dim,
                                   stride      = strides[level],
+                                  reg_max     = reg_max,
                                   num_classes = num_classes,
                                   num_coords  = num_coords * reg_max)
                                   for level in range(num_levels)
@@ -74,20 +115,6 @@ class MultiLevelPredLayer(nn.Module):
         self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
         self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, reg_max, 1, 1]))
 
-    def generate_anchors(self, level, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchors += 0.5  # add center offset
-        anchors *= self.strides[level]
-
-        return anchors
-        
     def forward(self, cls_feats, reg_feats):
         all_anchors = []
         all_strides = []
@@ -96,25 +123,13 @@ class MultiLevelPredLayer(nn.Module):
         all_box_preds = []
         all_delta_preds = []
         for level in range(self.num_levels):
-            # pred
-            cls_pred, reg_pred = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
-
-            # generate anchor boxes: [M, 4]
-            B, _, H, W = cls_pred.size()
-            fmp_size = [H, W]
-            anchors = self.generate_anchors(level, fmp_size)
-            anchors = anchors.to(cls_pred.device)
-            # stride tensor: [M, 1]
-            stride_tensor = torch.ones_like(anchors[..., :1]) * self.strides[level]
-            
-            # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-
-            # ----------------------- Decode bbox -----------------------
-            B, M = reg_pred.shape[:2]
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
             # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
-            delta_pred = reg_pred.reshape([B, M, 4, self.reg_max])
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.reg_max])
             # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
             delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
             # [B, reg_max, 4, M] -> [B, 1, 4, M]
@@ -122,16 +137,16 @@ class MultiLevelPredLayer(nn.Module):
             # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
             delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
             ## tlbr -> xyxy
-            x1y1_pred = anchors[None] - delta_pred[..., :2] * self.strides[level]
-            x2y2_pred = anchors[None] + delta_pred[..., 2:] * self.strides[level]
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.strides[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.strides[level]
             box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
 
-            all_cls_preds.append(cls_pred)
-            all_reg_preds.append(reg_pred)
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
             all_box_preds.append(box_pred)
             all_delta_preds.append(delta_pred)
-            all_anchors.append(anchors)
-            all_strides.append(stride_tensor)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
         
         # output dict
         outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]

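Note on the new decode path: the per-level reg_max output and the frozen proj_conv implement a DFL-style regression, where each box side is predicted as a distribution over reg_max bins and collapsed to its expected value by the 1x1 projection before being scaled by the level stride around the anchor center. A minimal standalone sketch of that decode is below; the arange projection weights, the softmax over the bin dimension, and all shapes/values are assumptions for illustration, not taken verbatim from the repo.

import torch
import torch.nn as nn

B, M, reg_max, stride = 2, 6400, 16, 8                 # assumed sizes for illustration
reg_pred = torch.randn(B, M, 4 * reg_max)              # raw regression, [B, M, 4*reg_max]
anchors = torch.rand(M, 2) * 640                       # anchor centers, [M, 2]

# Frozen 1x1 projection whose weights are 0..reg_max-1: applied after a softmax
# over the bins, it yields the expected bin index per side (the DFL expectation).
proj_conv = nn.Conv2d(reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
proj_conv.weight.data[:] = torch.arange(reg_max, dtype=torch.float).view(1, reg_max, 1, 1)

# [B, M, 4*reg_max] -> [B, M, 4, reg_max] -> [B, reg_max, 4, M]
delta = reg_pred.reshape(B, M, 4, reg_max).permute(0, 3, 2, 1).contiguous()
# [B, reg_max, 4, M] -> [B, 1, 4, M] -> [B, M, 4] (tlbr distances in grid units)
delta = proj_conv(delta.softmax(dim=1)).view(B, 4, M).permute(0, 2, 1).contiguous()

# tlbr -> xyxy around the anchor centers
x1y1 = anchors[None] - delta[..., :2] * stride
x2y2 = anchors[None] + delta[..., 2:] * stride
boxes = torch.cat([x1y1, x2y2], dim=-1)                # [B, M, 4]
print(boxes.shape)                                     # torch.Size([2, 6400, 4])
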
models/detectors/yolox2/yolox2_pred.py (+67, -49)

@@ -5,7 +5,12 @@ import torch.nn as nn
 
 # Single-level pred layer
 class SingleLevelPredLayer(nn.Module):
-    def __init__(self, cls_dim, reg_dim, stride, num_classes, num_coords=4):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
         super().__init__()
         # --------- Basic Parameters ----------
         self.stride = stride
@@ -30,18 +35,56 @@ class SingleLevelPredLayer(nn.Module):
         b.data.fill_(1.0)
         self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
-    def forward(self, cls_feat, reg_feat):
+    def generate_anchors(self, fmp_size):
         """
-            in_feats: (Tensor) [B, C, H, W]
+            fmp_size: (List) [H, W]
         """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
         cls_pred = self.cls_pred(cls_feat)
         reg_pred = self.reg_pred(reg_feat)
 
-        return cls_pred, reg_pred
-    
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+
+        # output dict
+        outputs = {"pred_cls": cls_pred,             # (Tensor) [B, M, C]
+                   "pred_reg": reg_pred,             # (Tensor) [B, M, 4]
+                   "anchors": anchors,               # (Tensor) [M, 2]
+                   "stride": self.stride,            # (Int)
+                   "stride_tensors": stride_tensor   # List(Tensor) [M, 1]
+                   }
+
+        return outputs
+
 # Multi-level pred layer
 class MultiLevelPredLayer(nn.Module):
-    def __init__(self, cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
+    def __init__(self,
+                 cls_dim,
+                 reg_dim,
+                 strides,
+                 num_classes :int = 80,
+                 num_coords  :int = 4,
+                 num_levels  :int = 3):
         super().__init__()
         # --------- Basic Parameters ----------
         self.cls_dim = cls_dim
@@ -52,7 +95,7 @@ class MultiLevelPredLayer(nn.Module):
         self.num_levels = num_levels
 
         # ----------- Network Parameters -----------
-        ## pred layers
+        ## multi-level pred layers
         self.multi_level_preds = nn.ModuleList(
             [SingleLevelPredLayer(cls_dim     = cls_dim,
                                   reg_dim     = reg_dim,
@@ -61,20 +104,6 @@ class MultiLevelPredLayer(nn.Module):
                                   num_coords  = num_coords)
                                   for level in range(num_levels)
                                   ])
-
-    def generate_anchors(self, level, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchors += 0.5  # add center offset
-        anchors *= self.strides[level]
-
-        return anchors
         
     def forward(self, cls_feats, reg_feats):
         all_anchors = []
@@ -83,41 +112,30 @@ class MultiLevelPredLayer(nn.Module):
         all_box_preds = []
         all_reg_preds = []
         for level in range(self.num_levels):
-            # pred
-            cls_pred, reg_pred = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
-
-            # generate anchor boxes: [M, 4]
-            B, _, H, W = cls_pred.size()
-            fmp_size = [H, W]
-            anchors = self.generate_anchors(level, fmp_size)
-            anchors = anchors.to(cls_pred.device)
-            # stride tensor: [M, 1]
-            stride_tensor = torch.ones_like(anchors[..., :1]) * self.strides[level]
-            
-            # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-
-            # ----------------------- Decode bbox -----------------------
-            ctr_pred = reg_pred[..., :2] * self.strides[level] + anchors[..., :2]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * self.strides[level]
+            # ---------------- Single level prediction ----------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # ---------------- Decode bbox ----------------
+            ctr_pred = outputs["pred_reg"][..., :2] * self.strides[level] + outputs["anchors"][..., :2]
+            wh_pred = torch.exp(outputs["pred_reg"][..., 2:]) * self.strides[level]
             pred_x1y1 = ctr_pred - wh_pred * 0.5
             pred_x2y2 = ctr_pred + wh_pred * 0.5
             box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
-            all_cls_preds.append(cls_pred)
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
             all_box_preds.append(box_pred)
-            all_reg_preds.append(reg_pred)
-            all_anchors.append(anchors)
-            all_strides.append(stride_tensor)
+            all_reg_preds.append(outputs["pred_reg"])
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensors"])
         
         # output dict
-        outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
-                   "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
-                   "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4]
-                   "anchors": all_anchors,           # List(Tensor) [M, 2]
-                   "strides": self.strides,          # List(Int) [8, 16, 32]
-                   "stride_tensors": all_strides     # List(Tensor) [M, 1]
+        outputs = {"pred_cls": all_cls_preds,      # List(Tensor) [B, M, C]
+                   "pred_box": all_box_preds,      # List(Tensor) [B, M, 4]
+                   "pred_reg": all_reg_preds,      # List(Tensor) [B, M, 4]
+                   "anchors": all_anchors,         # List(Tensor) [M, 2]
+                   "strides": self.strides,        # List(Int) [8, 16, 32]
+                   "stride_tensors": all_strides   # List(Tensor) [M, 1]
                    }
 
         return outputs
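
Note on the yolox2 decode: the single-level dict keeps the raw 4-channel regression, and MultiLevelPredLayer turns it into boxes by adding a stride-scaled center offset to each grid anchor and applying an exponential to the width/height. A minimal one-level sketch of generate_anchors plus that decode follows; the grid size, stride, and batch size are assumptions for illustration (meshgrid is called with indexing="ij", which matches the layout of the repo's older meshgrid call).

import torch

stride, fmp_h, fmp_w = 8, 80, 80

# Grid anchors as in generate_anchors: cell centers in image coordinates, [M, 2].
ys, xs = torch.meshgrid(torch.arange(fmp_h), torch.arange(fmp_w), indexing="ij")
anchors = (torch.stack([xs, ys], dim=-1).float().view(-1, 2) + 0.5) * stride

B, M = 2, fmp_h * fmp_w
reg_pred = torch.randn(B, M, 4)                        # raw regression, [B, M, 4]

# Decode: stride-scaled center offset, exponential width/height, then cxcywh -> xyxy.
ctr = reg_pred[..., :2] * stride + anchors             # box centers
wh = torch.exp(reg_pred[..., 2:]) * stride             # box sizes
boxes = torch.cat([ctr - wh * 0.5, ctr + wh * 0.5], dim=-1)   # [B, M, 4]
print(boxes.shape)                                     # torch.Size([2, 6400, 4])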