| 1234567891011121314151617181920212223 |
- import torch
- import torch.nn as nn
- from typing import List
- # --------------------- Basic modules ---------------------
class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size (int or tuple, passed to ``nn.Conv2d``).
        padding: Zero-padding added to the input (passed to ``nn.Conv2d``).
        stride: Convolution stride (passed to ``nn.Conv2d``).
        dilation: Spacing between kernel elements (passed to ``nn.Conv2d``).
        groups: Number of blocked connections from input to output channels
            (passed to ``nn.Conv2d``). The default of 1 gives a regular dense
            convolution. Both ``in_dim`` and ``out_dim`` must be divisible by it.
    """

    def __init__(self,
                 in_dim,          # in channels
                 out_dim,         # out channels
                 kernel_size=1,   # kernel size
                 padding=0,       # padding
                 stride=1,        # stride
                 dilation=1,      # dilation
                 groups=1,        # conv groups (1 = dense conv)
                 ):
        super().__init__()
        # No conv bias: the following BatchNorm's affine shift makes it redundant.
        self.conv = nn.Conv2d(in_dim, out_dim,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups,
                              bias=False)
        self.norm = nn.BatchNorm2d(out_dim)
        # In-place activation overwrites the BatchNorm output tensor to save memory.
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, then batch norm, then LeakyReLU to ``x``."""
        return self.act(self.norm(self.conv(x)))
|