import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv(nn.Module):
    """Convolution block: ``Conv2d`` -> ``BatchNorm2d`` -> optional ``LeakyReLU``.

    The three stages are stored, in that order, in the ``convs`` attribute
    (an ``nn.Sequential``). The convolution is created without an explicit
    ``bias`` argument, so it keeps ``nn.Conv2d``'s default.

    Args:
        in_dim: Number of input channels given to ``nn.Conv2d``.
        out_dim: Number of output channels; also the feature count of the
            batch-norm layer.
        k: Kernel size, forwarded as ``nn.Conv2d``'s third positional
            argument.
        s: Stride of the convolution.
        p: Padding of the convolution.
        d: Dilation of the convolution.
        g: Number of groups of the convolution.
        act: When truthy, the block ends with ``LeakyReLU(0.1, inplace=True)``;
            otherwise an ``nn.Identity`` occupies the last slot so the
            sequential always has three entries.
    """

    def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, g=1, act=True):
        super().__init__()

        # Built in execution order; only the convolution draws from the RNG
        # during initialisation.
        conv = nn.Conv2d(
            in_dim,
            out_dim,
            k,
            stride=s,
            padding=p,
            dilation=d,
            groups=g,
        )
        norm = nn.BatchNorm2d(out_dim)
        if act:
            activation = nn.LeakyReLU(0.1, inplace=True)
        else:
            activation = nn.Identity()

        self.convs = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the conv / batch-norm / activation stack."""
        out = self.convs(x)
        return out