@@ -194,52 +194,26 @@ class BasicConv(nn.Module):
                  stride=1,                   # stride
                  act_type :str = 'lrelu',    # activation
                  norm_type :str = 'BN',      # normalization
+                 depthwise :bool = False     # use a depthwise-separable conv
                  ):
         super(BasicConv, self).__init__()
         add_bias = False if norm_type else True
-        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
-        self.norm = get_norm(norm_type, out_dim)
-        self.act = get_activation(act_type)
-
-    def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
-
-class DepthwiseConv(nn.Module):
-    def __init__(self,
-                 in_dim,                 # in channels
-                 out_dim,                # out channels
-                 kernel_size=1,          # kernel size
-                 padding=0,              # padding
-                 stride=1,               # stride
-                 act_type :str = None,   # activation
-                 norm_type :str = 'BN',  # normalization
-                 ):
-        super(DepthwiseConv, self).__init__()
-        assert in_dim == out_dim
-        add_bias = False if norm_type else True
-        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
-        self.norm = get_norm(norm_type, out_dim)
-        self.act = get_activation(act_type)
-
-    def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
-
-class PointwiseConv(nn.Module):
-    def __init__(self,
-                 in_dim,                 # in channels
-                 out_dim,                # out channels
-                 act_type :str = 'lrelu', # activation
-                 norm_type :str = 'BN',  # normalization
-                 ):
-        super(DepthwiseConv, self).__init__()
-        assert in_dim == out_dim
-        add_bias = False if norm_type else True
-        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
-        self.norm = get_norm(norm_type, out_dim)
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, g=in_dim, bias=add_bias)  # depthwise: one filter per channel
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)  # pointwise: 1x1 channel projection
+            self.norm2 = get_norm(norm_type, out_dim)
         self.act = get_activation(act_type)

     def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            return self.act(self.norm2(self.conv2(self.norm1(self.conv1(x)))))


# ----------------- CNN Modules -----------------
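
For context, a minimal usage sketch of the merged module, not part of the patch. It assumes BasicConv and its get_conv2d / get_norm / get_activation helpers are importable from this module, and that the depthwise branch is a standard depthwise-separable pair (per-channel k x k conv with g=in_dim, then a 1x1 pointwise conv):

import torch

# standard conv: one k x k convolution over all input channels
conv    = BasicConv(64, 128, kernel_size=3, padding=1, stride=1)
# depthwise-separable: k x k depthwise conv followed by a 1x1 pointwise projection
dw_conv = BasicConv(64, 128, kernel_size=3, padding=1, stride=1, depthwise=True)

x = torch.randn(2, 64, 32, 32)
print(conv(x).shape)     # torch.Size([2, 128, 32, 32]) -- padding=1 keeps spatial size
print(dw_conv(x).shape)  # torch.Size([2, 128, 32, 32])

The usual motivation for the flag: a 3x3 standard conv from 64 to 128 channels has 3*3*64*128 = 73,728 weights, while the separable version has 3*3*64 + 64*128 = 8,768, roughly an 8x reduction.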