| 12345678910111213141516 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> optional LeakyReLU(0.1) building block.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        k: Convolution kernel size.
        s: Convolution stride. Defaults to 1.
        p: Zero-padding added on each spatial side. Defaults to 0.
        d: Convolution dilation. Defaults to 1.
        g: Number of convolution groups. Defaults to 1.
        act: If True, append ``nn.LeakyReLU(0.1, inplace=True)``; otherwise
            an ``nn.Identity`` is used, so ``self.convs`` always has three
            entries.
        bias: Whether the convolution has a learnable bias (keyword-only).
            Defaults to True to keep the original parameter set, including
            the ``convs.0.bias`` state-dict entry. Pass False to drop it:
            the following BatchNorm2d subtracts a per-channel mean, which
            makes a constant per-channel conv bias redundant.
    """

    def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, g=1, act=True, *, bias=True):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(
                in_dim,
                out_dim,
                k,
                stride=s,
                padding=p,
                dilation=d,
                groups=g,
                bias=bias,
            ),
            nn.BatchNorm2d(out_dim),
            # inplace=True overwrites the BatchNorm output to save memory.
            nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity(),
        )

    def forward(self, x):
        """Apply convolution, batch-norm and the optional activation to ``x``.

        ``x`` must be a tensor accepted by ``nn.Conv2d`` with ``in_dim``
        channels; presumably batched NCHW, since BatchNorm2d follows.
        """
        return self.convs(x)
|