modules.py 919 B

1234567891011121314151617181920212223
  1. import torch
  2. import torch.nn as nn
  3. from typing import List
# --------------------- Basic modules ---------------------
  5. class ConvModule(nn.Module):
  6. def __init__(self,
  7. in_dim: int, # in channels
  8. out_dim: int, # out channels
  9. kernel_size: int = 1, # kernel size
  10. padding: int = 0, # padding
  11. stride: int = 1, # padding
  12. dilation: int = 1, # dilation
  13. use_act: bool = False,
  14. ):
  15. super(ConvModule, self).__init__()
  16. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  17. self.norm = nn.BatchNorm2d(out_dim)
  18. self.act = nn.ReLU(inplace=True) if use_act else nn.Identity()
  19. def forward(self, x):
  20. return self.act(self.norm(self.conv(x)))