# yolof_encoder.py — ELAN-style dilated encoder (neck) for YOLOF.
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolof_basic import BasicConv
  5. except:
  6. from yolof_basic import BasicConv
  7. # BottleNeck
  8. class Bottleneck(nn.Module):
  9. def __init__(self,
  10. in_dim :int,
  11. out_dim :int,
  12. dilation :int,
  13. expand_ratio :float = 0.5,
  14. shortcut :bool = False,
  15. act_type :str = 'relu',
  16. norm_type :str = 'BN',
  17. depthwise :bool = False,):
  18. super(Bottleneck, self).__init__()
  19. # ------------------ Basic parameters -------------------
  20. self.in_dim = in_dim
  21. self.out_dim = out_dim
  22. self.dilation = dilation
  23. self.expand_ratio = expand_ratio
  24. self.shortcut = shortcut and in_dim == out_dim
  25. inter_dim = round(in_dim * expand_ratio)
  26. # ------------------ Network parameters -------------------
  27. self.branch = nn.Sequential(
  28. BasicConv(in_dim, inter_dim,
  29. kernel_size=1, padding=0, stride=1,
  30. act_type=act_type, norm_type=norm_type),
  31. BasicConv(inter_dim, inter_dim,
  32. kernel_size=3, padding=dilation, dilation=dilation, stride=1,
  33. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  34. BasicConv(inter_dim, in_dim,
  35. kernel_size=1, padding=0, stride=1,
  36. act_type=act_type, norm_type=norm_type)
  37. )
  38. def forward(self, x):
  39. h = self.branch(x)
  40. return x + self.branch(x) if self.shortcut else h
  41. # ELAN-style Dilated Encoder
  42. class YolofEncoder(nn.Module):
  43. def __init__(self, cfg, in_dim, out_dim):
  44. super(YolofEncoder, self).__init__()
  45. # ------------------ Basic parameters -------------------
  46. self.in_dim = in_dim
  47. self.out_dim = out_dim
  48. self.expand_ratio = cfg.neck_expand_ratio
  49. self.dilations = cfg.neck_dilations
  50. # ------------------ Network parameters -------------------
  51. ## input layer
  52. self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
  53. ## dilated layers
  54. self.module = nn.ModuleList([Bottleneck(in_dim = out_dim,
  55. out_dim = out_dim,
  56. dilation = dilation,
  57. expand_ratio = self.expand_ratio,
  58. shortcut = True,
  59. act_type = cfg.neck_act,
  60. norm_type = cfg.neck_norm,
  61. depthwise = cfg.neck_depthwise,
  62. ) for dilation in self.dilations])
  63. ## output layer
  64. self.output_proj = BasicConv(out_dim * (len(self.dilations) + 1), out_dim,
  65. kernel_size=1, padding=0, stride=1,
  66. act_type=cfg.neck_act, norm_type=cfg.neck_norm)
  67. # Initialize all layers
  68. self.init_weights()
  69. def init_weights(self):
  70. """Initialize the parameters."""
  71. for m in self.modules():
  72. if isinstance(m, torch.nn.Conv2d):
  73. # In order to be consistent with the source code,
  74. # reset the Conv2d initialization parameters
  75. m.reset_parameters()
  76. def forward(self, x):
  77. x = self.input_proj(x)
  78. out = [x]
  79. for m in self.module:
  80. x = m(x)
  81. out.append(x)
  82. out = self.output_proj(torch.cat(out, dim=1))
  83. return out