| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- import math
- import torch.nn as nn
- from .ctrnet_basic import DeConv, DeformableConv
def build_decoder(cfg, in_dim, out_dim):
    """Construct a ``CTRDecoder`` from a configuration mapping.

    Args:
        cfg: Mapping that must provide the keys ``'max_stride'``,
            ``'out_stride'``, ``'dec_act'``, ``'dec_norm'`` and
            ``'dec_depthwise'``.
        in_dim: Input channel count, forwarded unchanged.
        out_dim: Per-stage output channels, forwarded unchanged.

    Returns:
        The newly built ``CTRDecoder`` instance.
    """
    # Translate config keys into CTRDecoder keyword arguments. The keys are
    # read in the same order as the constructor's parameter list, so a missing
    # key fails on the same lookup as before.
    decoder_kwargs = {
        'max_stride': cfg['max_stride'],
        'out_stride': cfg['out_stride'],
        'act_type': cfg['dec_act'],
        'norm_type': cfg['dec_norm'],
        'depthwise': cfg['dec_depthwise'],
    }
    return CTRDecoder(in_dim=in_dim, out_dim=out_dim, **decoder_kwargs)
class CTRDecoder(nn.Module):
    """Stack of ``DeformableConv`` -> ``DeConv`` stages applied in sequence.

    The number of stages is ``round(log2(max_stride // out_stride))``. Each
    stage is a 3x3 ``DeformableConv`` (padding 1, stride 1) followed by a
    ``DeConv`` with ``kernel_size=4, stride=2``. Presumably each ``DeConv``
    upsamples spatially by 2 -- TODO confirm against ``ctrnet_basic.DeConv``.

    Args:
        in_dim: Channel count of the input to the first stage.
        out_dim: Output channels per stage. Either an indexable sequence with
            at least one entry per stage (extra entries are ignored) or a
            single ``int`` that is used for every stage.
        max_stride: Stride of the decoder's input feature map.
        out_stride: Stride the decoder is meant to produce.
        act_type: Activation name forwarded to every ``DeConv``.
        norm_type: Normalization name forwarded to every ``DeConv``.
        depthwise: Accepted for config compatibility; not used by this class.

    Raises:
        ValueError: If ``out_stride`` is not positive, if ``max_stride`` is
            smaller than ``out_stride``, or if ``out_dim`` is a sequence with
            fewer entries than there are stages.
    """

    def __init__(self,
                 in_dim: int,
                 out_dim: "int | list[int]",
                 max_stride: int,
                 out_stride: int,
                 act_type: str,
                 norm_type: str,
                 depthwise: bool
                 ):
        super().__init__()
        # Validate strides up front. Otherwise `max_stride // out_stride`
        # divides by zero, or log2 of a non-positive ratio raises an opaque
        # "math domain error".
        if out_stride <= 0:
            raise ValueError(f"out_stride must be positive, got {out_stride}")
        if max_stride < out_stride:
            raise ValueError(
                f"max_stride ({max_stride}) must be >= out_stride ({out_stride})")
        # ---------- Basic parameters ----------
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.out_stride = out_stride
        # NOTE: the stride ratio is floor-divided and its log2 is rounded. For
        # a non-power-of-two ratio, the number of stages therefore does not
        # reach exactly `out_stride`. Assuming each stage upsamples by 2, the
        # real output stride is max_stride / 2**num_layers.
        self.num_layers = round(math.log2(max_stride // out_stride))

        # Per-stage output channels: broadcast a scalar, otherwise require
        # one entry per stage (fail here rather than with an IndexError
        # halfway through building the layers).
        if isinstance(out_dim, int):
            stage_dims = [out_dim] * self.num_layers
        else:
            stage_dims = out_dim
            if self.num_layers > 0 and len(stage_dims) < self.num_layers:
                raise ValueError(
                    f"out_dim needs at least {self.num_layers} entries "
                    f"(one per stage), got {len(stage_dims)}")

        # ---------- Network parameters ----------
        # NOTE(review): `depthwise` is never forwarded to DeformableConv or
        # DeConv; confirm whether ctrnet_basic supports it before wiring it up.
        layers = []
        ch_in = in_dim
        for i in range(self.num_layers):
            ch_out = stage_dims[i]
            layers.append(nn.Sequential(
                DeformableConv(ch_in, ch_out, kernel_size=3, padding=1, stride=1),
                DeConv(ch_out, ch_out, kernel_size=4, stride=2,
                       act_type=act_type, norm_type=norm_type),
            ))
            ch_in = ch_out  # the next stage consumes this stage's output
        # With zero stages this is an empty Sequential, i.e. an identity map.
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every decoder stage in order and return the result."""
        return self.layers(x)
|