|
|
@@ -79,9 +79,6 @@ class Conv(nn.Module):
|
|
|
|
|
|
# ELAN Block
|
|
|
class ELANBlock(nn.Module):
|
|
|
- """
|
|
|
- ELAN BLock of YOLOv7's backbone
|
|
|
- """
|
|
|
def __init__(self, in_dim, out_dim, expand_ratio=0.5, act_type='silu', norm_type='BN', depthwise=False):
|
|
|
super(ELANBlock, self).__init__()
|
|
|
inter_dim = int(in_dim * expand_ratio)
|
|
|
@@ -101,18 +98,10 @@ class ELANBlock(nn.Module):
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
- """
|
|
|
- Input:
|
|
|
- x: [B, C, H, W]
|
|
|
- Output:
|
|
|
- out: [B, 2C, H, W]
|
|
|
- """
|
|
|
x1 = self.cv1(x)
|
|
|
x2 = self.cv2(x)
|
|
|
x3 = self.cv3(x2)
|
|
|
x4 = self.cv4(x3)
|
|
|
-
|
|
|
- # [B, C, H, W] -> [B, 2C, H, W]
|
|
|
out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
|
|
|
|
|
|
return out
|
|
|
@@ -131,17 +120,8 @@ class DownSample(nn.Module):
|
|
|
)
|
|
|
|
|
|
def forward(self, x):
|
|
|
- """
|
|
|
- Input:
|
|
|
- x: [B, C, H, W]
|
|
|
- Output:
|
|
|
- out: [B, C, H//2, W//2]
|
|
|
- """
|
|
|
- # [B, C, H, W] -> [B, C//2, H//2, W//2]
|
|
|
x1 = self.cv1(self.mp(x))
|
|
|
x2 = self.cv2(x)
|
|
|
-
|
|
|
- # [B, C, H//2, W//2]
|
|
|
out = torch.cat([x1, x2], dim=1)
|
|
|
|
|
|
return out
|
|
|
@@ -149,9 +129,6 @@ class DownSample(nn.Module):
|
|
|
|
|
|
# ELAN Block for PaFPN
|
|
|
class ELANBlockFPN(nn.Module):
|
|
|
- """
|
|
|
- ELAN BLock of YOLOv7's head
|
|
|
- """
|
|
|
def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
|
|
|
super(ELANBlockFPN, self).__init__()
|
|
|
# Basic parameters
|
|
|
@@ -181,12 +158,6 @@ class ELANBlockFPN(nn.Module):
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
- """
|
|
|
- Input:
|
|
|
- x: [B, C_in, H, W]
|
|
|
- Output:
|
|
|
- out: [B, C_out, H, W]
|
|
|
- """
|
|
|
x1 = self.cv1(x)
|
|
|
x2 = self.cv2(x)
|
|
|
inter_outs = [x1, x2]
|
|
|
@@ -194,8 +165,6 @@ class ELANBlockFPN(nn.Module):
|
|
|
y1 = inter_outs[-1]
|
|
|
y2 = m(y1)
|
|
|
inter_outs.append(y2)
|
|
|
-
|
|
|
- # [B, C_in, H, W] -> [B, C_out, H, W]
|
|
|
out = self.out(torch.cat(inter_outs, dim=1))
|
|
|
|
|
|
return out
|
|
|
@@ -214,17 +183,8 @@ class DownSampleFPN(nn.Module):
|
|
|
)
|
|
|
|
|
|
def forward(self, x):
|
|
|
- """
|
|
|
- Input:
|
|
|
- x: [B, C, H, W]
|
|
|
- Output:
|
|
|
- out: [B, 2C, H//2, W//2]
|
|
|
- """
|
|
|
- # [B, C, H, W] -> [B, C//2, H//2, W//2]
|
|
|
x1 = self.cv1(self.mp(x))
|
|
|
x2 = self.cv2(x)
|
|
|
-
|
|
|
- # [B, C, H//2, W//2]
|
|
|
out = torch.cat([x1, x2], dim=1)
|
|
|
|
|
|
return out
|