yolof_encoder.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolof_basic import BasicConv
  5. except:
  6. from yolof_basic import BasicConv
  7. # BottleNeck
  8. class Bottleneck(nn.Module):
  9. def __init__(self,
  10. in_dim :int,
  11. out_dim :int,
  12. dilation :int,
  13. expand_ratio :float = 0.5,
  14. shortcut :bool = False,
  15. act_type :str = 'relu',
  16. norm_type :str = 'BN',
  17. depthwise :bool = False,):
  18. super(Bottleneck, self).__init__()
  19. # ------------------ Basic parameters -------------------
  20. self.in_dim = in_dim
  21. self.out_dim = out_dim
  22. self.dilation = dilation
  23. self.expand_ratio = expand_ratio
  24. self.shortcut = shortcut and in_dim == out_dim
  25. inter_dim = round(in_dim * expand_ratio)
  26. # ------------------ Network parameters -------------------
  27. self.branch = nn.Sequential(
  28. BasicConv(in_dim, inter_dim,
  29. kernel_size=1, padding=0, stride=1,
  30. act_type=act_type, norm_type=norm_type),
  31. BasicConv(inter_dim, inter_dim,
  32. kernel_size=3, padding=dilation, dilation=dilation, stride=1,
  33. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  34. BasicConv(inter_dim, in_dim,
  35. kernel_size=1, padding=0, stride=1,
  36. act_type=act_type, norm_type=norm_type)
  37. )
  38. def forward(self, x):
  39. h = self.branch(x)
  40. return x + self.branch(x) if self.shortcut else h
  41. # CSP-style Dilated Encoder
  42. class YolofEncoder(nn.Module):
  43. def __init__(self, cfg, in_dim, out_dim):
  44. super(YolofEncoder, self).__init__()
  45. # ------------------ Basic parameters -------------------
  46. self.in_dim = in_dim
  47. self.out_dim = out_dim
  48. self.expand_ratio = cfg.neck_expand_ratio
  49. self.dilations = cfg.neck_dilations
  50. # ------------------ Network parameters -------------------
  51. ## proj layer
  52. self.projector = nn.Sequential(
  53. BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=cfg.neck_norm),
  54. BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=cfg.neck_norm)
  55. )
  56. ## encoder layers
  57. self.encoders = nn.Sequential(*[Bottleneck(in_dim = out_dim,
  58. out_dim = out_dim,
  59. dilation = d,
  60. expand_ratio = self.expand_ratio,
  61. shortcut = True,
  62. act_type = cfg.neck_act,
  63. norm_type = cfg.neck_norm,
  64. depthwise = cfg.neck_depthwise,
  65. ) for d in self.dilations])
  66. # Initialize all layers
  67. self.init_weights()
  68. def init_weights(self):
  69. """Initialize the parameters."""
  70. for m in self.modules():
  71. if isinstance(m, torch.nn.Conv2d):
  72. # In order to be consistent with the source code,
  73. # reset the Conv2d initialization parameters
  74. m.reset_parameters()
  75. def forward(self, x):
  76. x = self.projector(x)
  77. x = self.encoders(x)
  78. return x