import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Optional, Callable


# ----------------- CNN modules -----------------
def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
    """Build an ``nn.Conv2d`` from short positional arguments.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        p: padding.
        s: stride.
        d: dilation.
        g: groups.
        bias: whether the convolution has a learnable bias.
    """
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
    return conv


def get_activation(act_type: Optional[str] = None) -> nn.Module:
    """Return a fresh activation module for ``act_type``.

    Supported values: ``'relu'``, ``'lrelu'`` (slope 0.1), ``'mish'``,
    ``'silu'`` (all in-place), or ``None`` for ``nn.Identity``.

    Raises:
        NotImplementedError: if ``act_type`` is not one of the above.
    """
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported activation type: {act_type!r}")


def get_norm(norm_type: Optional[str], dim: int, num_groups: int = 32) -> nn.Module:
    """Return a normalization module over ``dim`` channels.

    Args:
        norm_type: ``'BN'`` (``BatchNorm2d``), ``'GN'`` (``GroupNorm``), or
            ``None`` (``nn.Identity``).
        dim: number of channels to normalize.
        num_groups: group count used only for ``'GN'``; ``dim`` must be
            divisible by it (``nn.GroupNorm`` rejects it otherwise).

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the above.
    """
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=num_groups, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported normalization type: {norm_type!r}")


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    # padding == dilation keeps the spatial size unchanged when stride == 1.
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def _norm_and_act(norm_type: Optional[str], act_type: Optional[str], dim: int) -> List[nn.Module]:
    """Return the optional norm and activation layers that follow a conv."""
    layers: List[nn.Module] = []
    if norm_type:
        layers.append(get_norm(norm_type, dim))
    if act_type:
        layers.append(get_activation(act_type))
    return layers


class Conv(nn.Module):
    """Convolution followed by optional normalization and activation.

    With ``depthwise=True`` the convolution is factored into a ``k x k``
    depthwise conv (``groups=c1``) and a ``1 x 1`` pointwise conv, each
    followed by its own norm/activation. Falsy ``act_type``/``norm_type``
    omit the corresponding layer.
    """

    def __init__(self,
                 c1,                                   # in channels
                 c2,                                   # out channels
                 k=1,                                  # kernel size
                 p=0,                                  # padding
                 s=1,                                  # stride
                 d=1,                                  # dilation
                 act_type: Optional[str] = 'lrelu',    # activation
                 norm_type: Optional[str] = 'BN',      # normalization
                 depthwise: bool = False):
        super().__init__()
        # A conv bias is redundant when a normalization layer follows it.
        add_bias = not norm_type
        convs: List[nn.Module] = []
        if depthwise:
            # depthwise conv: spatial filtering per input channel
            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            convs.extend(_norm_and_act(norm_type, act_type, c1))
            # pointwise conv: mixes channels c1 -> c2
            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
            convs.extend(_norm_and_act(norm_type, act_type, c2))
        else:
            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
            convs.extend(_norm_and_act(norm_type, act_type, c2))
        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        return self.convs(x)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs (conv-norm-relu, conv-norm, add, relu).

    ``downsample`` is applied to the shortcut and must be supplied when the
    conv path changes shape (``stride != 1`` or ``inplanes != planes``),
    otherwise the residual addition fails.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce, 3x3 (strided/grouped/dilated), 1x1 expand.

    The output has ``planes * expansion`` (= ``4 * planes``) channels. The
    inner width is ``int(planes * base_width / 64) * groups``. ``downsample``
    is applied to the shortcut and must be supplied whenever the conv path
    changes the shape, otherwise the residual addition fails.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


# ----------------- Transformer modules -----------------