yolof_encoder.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. # BottleNeck
  8. class Bottleneck(nn.Module):
  9. def __init__(self, in_dim: int, dilation: int = 1, expansion: float = 0.5):
  10. super(Bottleneck, self).__init__()
  11. # ------------------ Basic parameters -------------------
  12. self.in_dim = in_dim
  13. self.dilation = dilation
  14. self.expansion = expansion
  15. inter_dim = round(in_dim * expansion)
  16. # ------------------ Network parameters -------------------
  17. self.branch = nn.Sequential(
  18. ConvModule(in_dim, inter_dim, kernel_size=1),
  19. ConvModule(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation),
  20. ConvModule(inter_dim, in_dim, kernel_size=1)
  21. )
  22. def forward(self, x):
  23. return x + self.branch(x)
  24. # Dilated Encoder
  25. class DilatedEncoder(nn.Module):
  26. def __init__(self, cfg, in_dim, out_dim):
  27. super(DilatedEncoder, self).__init__()
  28. # ------------------ Basic parameters -------------------
  29. self.in_dim = in_dim
  30. self.out_dim = out_dim
  31. self.expand_ratio = cfg.neck_expand_ratio
  32. self.dilations = cfg.neck_dilations
  33. # ------------------ Network parameters -------------------
  34. ## proj layer
  35. self.projector = nn.Sequential(
  36. ConvModule(in_dim, out_dim, kernel_size=1, use_act=False),
  37. ConvModule(out_dim, out_dim, kernel_size=3, padding=1, use_act=False)
  38. )
  39. ## encoder layers
  40. self.encoders = nn.Sequential(
  41. *[Bottleneck(in_dim = out_dim,
  42. dilation = d,
  43. expansion = self.expand_ratio,
  44. ) for d in self.dilations])
  45. def forward(self, x):
  46. x = self.projector(x)
  47. x = self.encoders(x)
  48. return x
  49. if __name__=='__main__':
  50. from thop import profile
  51. # YOLOv1 configuration
  52. class YolofBaseConfig(object):
  53. def __init__(self) -> None:
  54. # ---------------- Model config ----------------
  55. self.out_stride = 32
  56. ## Backbone
  57. self.backbone = 'resnet18'
  58. self.use_pretrained = True
  59. self.neck_expand_ratio = 0.25
  60. self.neck_dilations = [2, 4, 6, 8]
  61. cfg = YolofBaseConfig()
  62. # Randomly generate a input data
  63. x = torch.randn(2, 512, 20, 20)
  64. # Build backbone
  65. model = DilatedEncoder(cfg, in_dim=512, out_dim=512)
  66. # Inference
  67. output = model(x)
  68. print(' - the shape of input : ', x.shape)
  69. print(' - the shape of output : ', output.shape)
  70. x = torch.randn(1, 512, 20, 20)
  71. flops, params = profile(model, inputs=(x, ), verbose=False)
  72. print('============== FLOPs & Params ================')
  73. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  74. print(' - Params : {:.2f} M'.format(params / 1e6))