| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import torch
- import torch.nn as nn
- from typing import List
- # --------------------- Basic modules ---------------------
class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU building block.

    The convolution is created without a bias term (``bias=False``) and is
    followed by a ``BatchNorm2d`` over ``out_dim`` channels and an in-place
    ``SiLU`` activation.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Padding passed to the convolution.
        stride: Convolution stride.
        dilation: Convolution dilation.
    """

    def __init__(self, in_dim, out_dim, kernel_size=1, padding=0, stride=1, dilation=1):
        super().__init__()
        conv_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "bias": False,
        }
        # Submodule names and creation order are kept stable (conv, norm, act)
        # so state_dict keys and weight initialisation stay the same.
        self.conv = nn.Conv2d(in_dim, out_dim, **conv_options)
        self.norm = nn.BatchNorm2d(out_dim)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and SiLU to ``x`` in that order."""
        feat = self.conv(x)
        feat = self.norm(feat)
        # The activation is in-place: it overwrites the normalised tensor.
        return self.act(feat)
class YoloBottleneck(nn.Module):
    """Two stacked ``ConvModule`` layers with an optional residual connection.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Kernel sizes of the first and second convolution
            (entries ``[0]`` and ``[1]`` are used). Defaults to ``[1, 3]``
            when omitted.
        expansion: Ratio of hidden channels to ``out_dim``; the hidden width
            is ``int(out_dim * expansion)``.
        shortcut: Request a residual connection. It is only active when
            ``in_dim == out_dim``.
    """

    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 kernel_size: List[int] = None,
                 expansion: float = 0.5,
                 shortcut: bool = False,
                 ):
        super(YoloBottleneck, self).__init__()
        # Build the default per call instead of sharing one mutable list
        # object across every instantiation.
        if kernel_size is None:
            kernel_size = [1, 3]
        inter_dim = int(out_dim * expansion)
        # ----------------- Network setting -----------------
        # padding = k // 2 at stride 1 keeps the spatial size for odd kernels.
        self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0] // 2, stride=1)
        self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1] // 2, stride=1)
        # The residual add needs matching channel counts.
        self.shortcut = shortcut and in_dim == out_dim

    def forward(self, x):
        """Run both conv layers and add the input back when the shortcut is active."""
        h = self.conv_layer2(self.conv_layer1(x))
        return x + h if self.shortcut else h
class C2fBlock(nn.Module):
    """Split-and-concatenate block built from ``YoloBottleneck`` layers.

    A 1x1 ``ConvModule`` projects the input to ``2 * inter_dim`` channels,
    which are split into two halves along dim 1. The second half is passed
    through a chain of ``num_blocks`` bottlenecks, each consuming the output
    of the previous one. Both halves and every bottleneck output are then
    concatenated along dim 1 and projected to ``out_dim`` by a second 1x1
    ``ConvModule``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        expansion: Ratio of hidden channels to ``out_dim``; the hidden width
            is ``round(out_dim * expansion)``.
        num_blocks: Number of chained bottlenecks.
        shortcut: Forwarded to every ``YoloBottleneck``.
    """

    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 expansion: float = 0.5,
                 num_blocks: int = 1,
                 shortcut: bool = False,
                 ):
        super().__init__()
        hidden = round(out_dim * expansion)
        self.input_proj = ConvModule(in_dim, hidden * 2, kernel_size=1)
        # Concatenated input: the two split halves plus one map per bottleneck.
        self.output_proj = ConvModule((2 + num_blocks) * hidden, out_dim, kernel_size=1)
        bottlenecks = []
        for _ in range(num_blocks):
            bottlenecks.append(
                YoloBottleneck(
                    in_dim=hidden,
                    out_dim=hidden,
                    kernel_size=[3, 3],
                    expansion=1.0,
                    shortcut=shortcut,
                )
            )
        self.module = nn.ModuleList(bottlenecks)

    def forward(self, x):
        """Project, split, run the bottleneck chain, then fuse all feature maps."""
        # Input projection, split into two equal halves along dim 1
        first_half, second_half = torch.chunk(self.input_proj(x), 2, dim=1)
        feats = [first_half, second_half]
        # Bottleneck chain: each block reads the most recently produced map
        for block in self.module:
            feats.append(block(feats[-1]))
        # Output projection over the concatenation of every collected map
        return self.output_proj(torch.cat(feats, dim=1))
|