|
|
@@ -72,30 +72,22 @@ class BasicConv(nn.Module):
|
|
|
|
|
|
# --------------------- GELAN modules (from yolov9) ---------------------
|
|
|
class ADown(nn.Module):
|
|
|
- def __init__(self, in_dim, out_dim, act_type="silu", norm_type="BN", depthwise=False, use_pooling=True):
|
|
|
+ def __init__(self, in_dim, out_dim, act_type="silu", norm_type="BN", depthwise=False):
|
|
|
super().__init__()
|
|
|
inter_dim = out_dim // 2
|
|
|
- self.use_pooling = use_pooling
|
|
|
- if use_pooling:
|
|
|
- self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim,
|
|
|
- kernel_size=3, padding=1, stride=2,
|
|
|
- act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
- self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
|
|
|
- act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
- else:
|
|
|
- self.conv_layer = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, stride=2,
|
|
|
- act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
+ self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim,
|
|
|
+ kernel_size=3, padding=1, stride=2,
|
|
|
+ act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
+ self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
|
|
|
+ act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
def forward(self, x):
|
|
|
- if self.use_pooling:
|
|
|
- x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
|
|
|
- x1,x2 = x.chunk(2, 1)
|
|
|
- x1 = self.conv_layer_1(x1)
|
|
|
- x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
|
|
|
- x2 = self.conv_layer_2(x2)
|
|
|
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
|
|
|
+ x1,x2 = x.chunk(2, 1)
|
|
|
+ x1 = self.conv_layer_1(x1)
|
|
|
+ x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
|
|
|
+ x2 = self.conv_layer_2(x2)
|
|
|
|
|
|
- return torch.cat((x1, x2), 1)
|
|
|
- else:
|
|
|
- return self.conv_layer(x)
|
|
|
+ return torch.cat((x1, x2), 1)
|
|
|
|
|
|
class RepConvN(nn.Module):
|
|
|
"""RepConv is a basic rep-style block, including training and deploy status
|