| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import torch
- import torch.nn as nn
- from typing import List
- # --------------------- Basic modules ---------------------
- class ConvModule(nn.Module):
- def __init__(self,
- in_dim, # in channels
- out_dim, # out channels
- kernel_size=1, # kernel size
- padding=0, # padding
- stride=1, # padding
- dilation=1, # dilation
- ):
- super(ConvModule, self).__init__()
- self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
- self.norm = nn.BatchNorm2d(out_dim)
- self.act = nn.SiLU(inplace=True)
- def forward(self, x):
- return self.act(self.norm(self.conv(x)))
- class YoloBottleneck(nn.Module):
- def __init__(self,
- in_dim :int,
- out_dim :int,
- kernel_size :List = [1, 3],
- expansion :float = 0.5,
- shortcut :bool = False,
- ) -> None:
- super(YoloBottleneck, self).__init__()
- inter_dim = int(out_dim * expansion)
- # ----------------- Network setting -----------------
- self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
- self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
- self.shortcut = shortcut and in_dim == out_dim
- def forward(self, x):
- h = self.conv_layer2(self.conv_layer1(x))
- return x + h if self.shortcut else h
- class CSPBlock(nn.Module):
- def __init__(self,
- in_dim,
- out_dim,
- num_blocks :int = 1,
- expansion :float = 0.5,
- shortcut :bool = False,
- ):
- super(CSPBlock, self).__init__()
- # ---------- Basic parameters ----------
- self.num_blocks = num_blocks
- self.expansion = expansion
- self.shortcut = shortcut
- inter_dim = round(out_dim * expansion)
- # ---------- Model parameters ----------
- self.conv_layer_1 = ConvModule(in_dim, inter_dim, kernel_size=1)
- self.conv_layer_2 = ConvModule(in_dim, inter_dim, kernel_size=1)
- self.conv_layer_3 = ConvModule(inter_dim * 2, out_dim, kernel_size=1)
- self.module = nn.Sequential(*[
- YoloBottleneck(inter_dim,
- inter_dim,
- kernel_size = [1, 3],
- expansion = 1.0,
- shortcut = shortcut,
- ) for _ in range(num_blocks)])
- def forward(self, x):
- x1 = self.conv_layer_1(x)
- x2 = self.module(self.conv_layer_2(x))
- out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
- return out
|