
add Albu & modify head init

yjh0410 committed 2 years ago
Commit b720fc4928
1 changed file with 12 additions and 12 deletions

models/detectors/yolov8/yolov8_pred.py  +12 -12

@@ -1,3 +1,4 @@
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -5,9 +6,10 @@ import torch.nn.functional as F
 
 # Single-level pred layer
 class SingleLevelPredLayer(nn.Module):
-    def __init__(self, cls_dim, reg_dim, num_classes, num_coords=4):
+    def __init__(self, cls_dim, reg_dim, stride, num_classes, num_coords=4):
         super().__init__()
         # --------- Basic Parameters ----------
+        self.stride = stride
         self.cls_dim = cls_dim
         self.reg_dim = reg_dim
         self.num_classes = num_classes
@@ -20,14 +22,11 @@ class SingleLevelPredLayer(nn.Module):
         self.init_bias()
         
     def init_bias(self):
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # cls pred
+        # cls pred bias
         b = self.cls_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
         self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred
+        # reg pred bias
         b = self.reg_pred.bias.view(-1, )
         b.data.fill_(1.0)
         self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
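Note: the new classification bias follows the YOLOv8-style prior of roughly 5 objects per 640x640 image, replacing the fixed 0.01 focal-loss prior. A rough sketch of the idea, not part of the commit, with example num_classes and stride values:

# Sketch (assumption: "5 objects per 640x640 image" prior, example sizes only)
import math
import torch
import torch.nn as nn

num_classes, stride = 80, 8                     # example values only
cls_pred = nn.Conv2d(64, num_classes, kernel_size=1)

prior = 5 / num_classes / (640. / stride) ** 2  # expected positive rate per cell and class
nn.init.constant_(cls_pred.bias, math.log(prior))

# sigmoid(log p) = p / (1 + p) ~= p for small p, so the head starts out
# predicting the prior instead of 0.5 for every class
print(torch.sigmoid(cls_pred.bias[0]).item())   # ~9.8e-06 for these values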
@@ -64,14 +63,15 @@ class MultiLevelPredLayer(nn.Module):
             [SingleLevelPredLayer(
                 cls_dim,
                 reg_dim,
+                strides[l],
                 num_classes,
                 num_coords * self.reg_max)
-                for _ in range(num_levels)
+                for l in range(num_levels)
             ])
         ## proj conv
-        self.proj = nn.Parameter(torch.linspace(0, reg_max, reg_max), requires_grad=False)
-        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
-        self.proj_conv.weight = nn.Parameter(self.proj.view([1, reg_max, 1, 1]).clone().detach(), requires_grad=False)
+        proj_init = torch.arange(reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, reg_max, 1, 1]))
 
     def generate_anchors(self, level, fmp_size):
         """
@@ -112,7 +112,7 @@ class MultiLevelPredLayer(nn.Module):
 
             # ----------------------- Decode bbox -----------------------
             B, M = reg_pred.shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
             delta_pred = reg_pred.reshape([B, M, 4, self.reg_max])
             # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
             delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
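Note: the reshape/permute documented in the last hunk only rearranges the regression logits so that reg_max sits on the channel axis expected by proj_conv. A shape-only sketch with made-up sizes:

# Sketch (B, M, reg_max are illustrative values)
import torch

B, M, reg_max = 2, 8400, 16
reg_pred = torch.randn(B, M, 4 * reg_max)

# [B, M, 4*reg_max] -> [B, M, 4, reg_max]
delta_pred = reg_pred.reshape(B, M, 4, reg_max)
# [B, M, 4, reg_max] -> [B, reg_max, 4, M], matching proj_conv's input layout
delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
print(delta_pred.shape)                         # torch.Size([2, 16, 4, 8400])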