import torch
import torch.nn as nn
from typing import List


# --------------------- Basic modules ---------------------
class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block.

    The convolution is created with ``bias=False`` because the BatchNorm2d
    that immediately follows it applies its own learnable shift.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Zero-padding added to both sides of the input.
        stride: Convolution stride.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections from input to output channels,
            forwarded to ``nn.Conv2d``. Defaults to 1 (a standard dense
            convolution), which matches the original behaviour.
    """

    def __init__(self,
                 in_dim,          # in channels
                 out_dim,         # out channels
                 kernel_size=1,   # kernel size
                 padding=0,       # padding
                 stride=1,        # stride
                 dilation=1,      # dilation
                 groups=1,        # conv groups (1 = dense conv)
                 ):
        super(ConvModule, self).__init__()
        self.conv = nn.Conv2d(in_dim, out_dim,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups,
                              bias=False)
        self.norm = nn.BatchNorm2d(out_dim)
        # In-place is safe here: the activation overwrites the BatchNorm
        # output, a fresh tensor that nothing else holds a reference to.
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        """Apply conv, then batch norm, then LeakyReLU to ``x``."""
        return self.act(self.norm(self.conv(x)))