# elandarknet.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import BasicConv, ELANLayer
  5. except:
  6. from modules import BasicConv, ELANLayer
  7. ## ELAN-based DarkNet
  8. class ELANDarkNet(nn.Module):
  9. def __init__(self, img_dim=3, width=1.0, depth=1.0, ratio=1.0, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
  10. super(ELANDarkNet, self).__init__()
  11. # ---------------- Basic parameters ----------------
  12. self.width_factor = width
  13. self.depth_factor = depth
  14. self.last_stage_factor = ratio
  15. self.feat_dims = [round(64 * width),
  16. round(128 * width),
  17. round(256 * width),
  18. round(512 * width),
  19. round(512 * width * ratio)
  20. ]
  21. # ---------------- Network parameters ----------------
  22. ## P1/2
  23. self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
  24. kernel_size=3, padding=1, stride=2,
  25. act_type=act_type, norm_type=norm_type)
  26. ## P2/4
  27. self.layer_2 = nn.Sequential(
  28. BasicConv(self.feat_dims[0], self.feat_dims[1],
  29. kernel_size=3, padding=1, stride=2,
  30. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  31. ELANLayer(in_dim = self.feat_dims[1],
  32. out_dim = self.feat_dims[1],
  33. num_blocks = round(3*depth),
  34. shortcut = True,
  35. act_type = act_type,
  36. norm_type = norm_type,
  37. depthwise = depthwise,
  38. )
  39. )
  40. ## P3/8
  41. self.layer_3 = nn.Sequential(
  42. BasicConv(self.feat_dims[1], self.feat_dims[2],
  43. kernel_size=3, padding=1, stride=2,
  44. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  45. ELANLayer(in_dim = self.feat_dims[2],
  46. out_dim = self.feat_dims[2],
  47. num_blocks = round(6*depth),
  48. shortcut = True,
  49. act_type = act_type,
  50. norm_type = norm_type,
  51. depthwise = depthwise,
  52. )
  53. )
  54. ## P4/16
  55. self.layer_4 = nn.Sequential(
  56. BasicConv(self.feat_dims[2], self.feat_dims[3],
  57. kernel_size=3, padding=1, stride=2,
  58. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  59. ELANLayer(in_dim = self.feat_dims[3],
  60. out_dim = self.feat_dims[3],
  61. num_blocks = round(6*depth),
  62. shortcut = True,
  63. act_type = act_type,
  64. norm_type = norm_type,
  65. depthwise = depthwise,
  66. )
  67. )
  68. ## P5/32
  69. self.layer_5 = nn.Sequential(
  70. BasicConv(self.feat_dims[3], self.feat_dims[4],
  71. kernel_size=3, padding=1, stride=2,
  72. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  73. ELANLayer(in_dim = self.feat_dims[4],
  74. out_dim = self.feat_dims[4],
  75. num_blocks = round(3*depth),
  76. shortcut = True,
  77. act_type = act_type,
  78. norm_type = norm_type,
  79. depthwise = depthwise,
  80. )
  81. )
  82. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  83. self.fc = nn.Linear(self.feat_dims[4], num_classes)
  84. def forward(self, x):
  85. c1 = self.layer_1(x)
  86. c2 = self.layer_2(c1)
  87. c3 = self.layer_3(c2)
  88. c4 = self.layer_4(c3)
  89. c5 = self.layer_5(c4)
  90. c5 = self.avgpool(c5)
  91. c5 = torch.flatten(c5, 1)
  92. c5 = self.fc(c5)
  93. return c5
  94. # ------------------------ Model Functions ------------------------
  95. def elandarknet_n(img_dim=3, num_classes=1000) -> ELANDarkNet:
  96. return ELANDarkNet(img_dim=img_dim,
  97. width=0.25,
  98. depth=0.34,
  99. ratio=2.0,
  100. act_type='silu',
  101. norm_type='BN',
  102. depthwise=False,
  103. num_classes=num_classes
  104. )
  105. def elandarknet_s(img_dim=3, num_classes=1000) -> ELANDarkNet:
  106. return ELANDarkNet(img_dim=img_dim,
  107. width=0.50,
  108. depth=0.34,
  109. ratio=2.0,
  110. act_type='silu',
  111. norm_type='BN',
  112. depthwise=False,
  113. num_classes=num_classes
  114. )
  115. def elandarknet_m(img_dim=3, num_classes=1000) -> ELANDarkNet:
  116. return ELANDarkNet(img_dim=img_dim,
  117. width=0.75,
  118. depth=0.67,
  119. ratio=1.5,
  120. act_type='silu',
  121. norm_type='BN',
  122. depthwise=False,
  123. num_classes=num_classes
  124. )
  125. def elandarknet_l(img_dim=3, num_classes=1000) -> ELANDarkNet:
  126. return ELANDarkNet(img_dim=img_dim,
  127. width=1.0,
  128. depth=1.0,
  129. ratio=1.0,
  130. act_type='silu',
  131. norm_type='BN',
  132. depthwise=False,
  133. num_classes=num_classes
  134. )
  135. def elandarknet_x(img_dim=3, num_classes=1000) -> ELANDarkNet:
  136. return ELANDarkNet(img_dim=img_dim,
  137. width=1.25,
  138. depth=1.34,
  139. ratio=1.0,
  140. act_type='silu',
  141. norm_type='BN',
  142. depthwise=False,
  143. num_classes=num_classes
  144. )
  145. if __name__ == '__main__':
  146. import torch
  147. from thop import profile
  148. # build model
  149. model = elandarknet_s()
  150. x = torch.randn(1, 3, 224, 224)
  151. print('==============================')
  152. flops, params = profile(model, inputs=(x, ), verbose=False)
  153. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  154. print('Params : {:.2f} M'.format(params / 1e6))