# ctrnet_decoder.py (1.6 KB)

import math
from typing import Sequence

import torch.nn as nn

from .ctrnet_basic import DeConv, DeformableConv
  4. def build_decoder(cfg, in_dim, out_dim):
  5. return CTRDecoder(in_dim = in_dim,
  6. out_dim = out_dim,
  7. max_stride = cfg['max_stride'],
  8. out_stride = cfg['out_stride'],
  9. act_type = cfg['dec_act'],
  10. norm_type = cfg['dec_norm'],
  11. depthwise = cfg['dec_depthwise']
  12. )
  13. class CTRDecoder(nn.Module):
  14. def __init__(self,
  15. in_dim :int,
  16. out_dim :int,
  17. max_stride :int,
  18. out_stride :int,
  19. act_type :str,
  20. norm_type :str,
  21. depthwise :bool
  22. ):
  23. super().__init__()
  24. # ---------- Basic parameters ----------
  25. self.in_dim = in_dim
  26. self.out_dim = out_dim
  27. self.out_stride = out_stride
  28. self.num_layers = round(math.log2(max_stride // out_stride))
  29. # ---------- Network parameters ----------
  30. layers = []
  31. for i in range(self.num_layers):
  32. layer = nn.Sequential(
  33. DeformableConv(in_dim, out_dim[i], kernel_size=3, padding=1, stride=1),
  34. DeConv(out_dim[i], out_dim[i], kernel_size=4, stride=2, act_type=act_type, norm_type=norm_type)
  35. )
  36. layers.append(layer)
  37. in_dim = out_dim[i]
  38. self.layers = nn.Sequential(*layers)
  39. def forward(self, x):
  40. return self.layers(x)