"""Small factory helpers and a conv/norm/activation building block."""
import math

import torch.nn as nn


# Preferred (and maximum) number of groups for GroupNorm layers.
_MAX_GN_GROUPS = 32

# Maps an activation name to a zero-argument factory so that every call
# returns a fresh module instance.
_ACTIVATION_FACTORIES = {
    None: nn.Identity,
    'relu': lambda: nn.ReLU(inplace=True),
    'lrelu': lambda: nn.LeakyReLU(0.1, inplace=True),
    'mish': lambda: nn.Mish(inplace=True),
    'silu': lambda: nn.SiLU(inplace=True),
    'gelu': nn.GELU,
}


def get_conv2d(c1, c2, k, p, s, d, g):
    """Build a 2D convolution.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        p: Padding.
        s: Stride.
        d: Dilation.
        g: Number of groups.

    Returns:
        The configured ``nn.Conv2d`` (bias left at the PyTorch default).
    """
    return nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g)


def get_activation(act_type=None):
    """Return a new activation module for ``act_type``.

    Args:
        act_type: One of ``None``, ``'relu'``, ``'lrelu'``, ``'mish'``,
            ``'silu'`` or ``'gelu'``. ``None`` yields ``nn.Identity``.

    Returns:
        A freshly constructed activation module.

    Raises:
        NotImplementedError: If ``act_type`` is not a supported name.
    """
    try:
        factory = _ACTIVATION_FACTORIES[act_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are equally unsupported.
        raise NotImplementedError(act_type) from None
    return factory()


def get_norm(norm_type, dim):
    """Return a new normalization module over ``dim`` channels.

    Args:
        norm_type: ``'BN'`` (BatchNorm2d), ``'GN'`` (GroupNorm) or ``None``
            (``nn.Identity``).
        dim: Number of channels to normalize.

    Returns:
        A freshly constructed normalization module.

    Raises:
        NotImplementedError: If ``norm_type`` is not a supported name.
    """
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    if norm_type == 'GN':
        # GroupNorm requires num_channels to be divisible by num_groups.
        # gcd(32, dim) is 32 whenever dim is a multiple of 32 (the previous
        # behaviour) and otherwise the largest power-of-two divisor of dim
        # that is <= 32, so widths such as 16 or 48 no longer raise.
        num_groups = math.gcd(_MAX_GN_GROUPS, dim)
        return nn.GroupNorm(num_groups=num_groups, num_channels=dim)
    if norm_type is None:
        return nn.Identity()
    raise NotImplementedError(norm_type)


def _norm_act_layers(dim, norm_type, act_type):
    """Return the optional [norm, activation] layers that follow a conv."""
    layers = []
    if norm_type:
        layers.append(get_norm(norm_type, dim))
    if act_type:
        layers.append(get_activation(act_type))
    return layers


# ----------------- CNN ops -----------------
class ConvModule(nn.Module):
    """Convolution followed by optional normalization and activation.

    With ``depthwise=False`` the module is ``conv -> [norm] -> [act]``.
    With ``depthwise=True`` it is a depthwise conv (``groups=c1``) followed
    by a 1x1 pointwise conv, each with its own ``[norm] -> [act]`` tail.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size of the (depthwise) convolution.
        p: Padding of the (depthwise) convolution.
        s: Stride of the (depthwise) convolution.
        d: Dilation passed to every convolution.
        act_type: Activation name for ``get_activation``; a falsy value
            omits the activation layer.
        norm_type: Normalization name for ``get_norm``; a falsy value omits
            the normalization layer.
        depthwise: Whether to use the depthwise-separable variant.
    """

    def __init__(self, c1, c2, k=1, p=0, s=1, d=1, act_type='relu', norm_type='BN', depthwise=False):
        super().__init__()
        layers = []
        if depthwise:
            # Depthwise conv: one filter group per input channel.
            layers.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1))
            layers.extend(_norm_act_layers(c1, norm_type, act_type))
            # Pointwise conv mixes channels; d is forwarded as before
            # (dilation does not alter a 1x1 kernel's output).
            layers.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1))
        else:
            layers.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1))
        layers.extend(_norm_act_layers(c2, norm_type, act_type))

        # Layer order is unchanged, so state_dict keys (convs.0, ...) match.
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.convs(x)