cspdarknet.py

import torch
import torch.nn as nn

try:
    from .modules import BasicConv, CSPBlock
except ImportError:
    from modules import BasicConv, CSPBlock
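
# NOTE: BasicConv and CSPBlock live in the sibling `modules.py` and are not
# shown here. As an assumption for this listing, the code below only relies on
# the following behavior from them:
#   - BasicConv(in_dim, out_dim, kernel_size, padding, stride,
#               act_type, norm_type, depthwise)
#     roughly Conv2d -> norm layer (e.g. BatchNorm2d for 'BN') -> activation
#     (e.g. SiLU for 'silu'), optionally built as a depthwise-separable conv.
#   - CSPBlock(in_dim, out_dim, num_blocks, expand_ratio, shortcut,
#              act_type, norm_type, depthwise)
#     a CSP stage: the input is split into two branches, one branch runs
#     through `num_blocks` bottleneck blocks (residual if `shortcut=True`),
#     and the branches are concatenated and fused back to `out_dim` channels.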

# ---------------------------- CSPDarkNet ----------------------------
class CSPDarkNet(nn.Module):
    def __init__(self, img_dim=3, width=1.0, depth=1.0, act_type='silu',
                 norm_type='BN', depthwise=False, num_classes=1000):
        super(CSPDarkNet, self).__init__()
        # ---------------- Basic parameters ----------------
        self.width_factor = width
        self.depth_factor = depth
        self.feat_dims = [round(64 * width),
                          round(128 * width),
                          round(256 * width),
                          round(512 * width),
                          round(1024 * width)]

        # ---------------- Model parameters ----------------
        ## P1/2
        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
                                 kernel_size=6, padding=2, stride=2,
                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        ## P2/4
        self.layer_2 = nn.Sequential(
            BasicConv(self.feat_dims[0], self.feat_dims[1],
                      kernel_size=3, padding=1, stride=2,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            CSPBlock(self.feat_dims[1],
                     self.feat_dims[1],
                     num_blocks=round(3 * depth),
                     expand_ratio=0.5,
                     shortcut=True,
                     act_type=act_type,
                     norm_type=norm_type,
                     depthwise=depthwise)
        )
        ## P3/8
        self.layer_3 = nn.Sequential(
            BasicConv(self.feat_dims[1], self.feat_dims[2],
                      kernel_size=3, padding=1, stride=2,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            CSPBlock(self.feat_dims[2],
                     self.feat_dims[2],
                     num_blocks=round(9 * depth),
                     expand_ratio=0.5,
                     shortcut=True,
                     act_type=act_type,
                     norm_type=norm_type,
                     depthwise=depthwise)
        )
        ## P4/16
        self.layer_4 = nn.Sequential(
            BasicConv(self.feat_dims[2], self.feat_dims[3],
                      kernel_size=3, padding=1, stride=2,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            CSPBlock(self.feat_dims[3],
                     self.feat_dims[3],
                     num_blocks=round(9 * depth),
                     expand_ratio=0.5,
                     shortcut=True,
                     act_type=act_type,
                     norm_type=norm_type,
                     depthwise=depthwise)
        )
        ## P5/32
        self.layer_5 = nn.Sequential(
            BasicConv(self.feat_dims[3], self.feat_dims[4],
                      kernel_size=3, padding=1, stride=2,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            CSPBlock(self.feat_dims[4],
                     self.feat_dims[4],
                     num_blocks=round(3 * depth),
                     expand_ratio=0.5,
                     shortcut=True,
                     act_type=act_type,
                     norm_type=norm_type,
                     depthwise=depthwise)
        )

        ## Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.feat_dims[4], num_classes)

    def forward(self, x):
        c1 = self.layer_1(x)   # P1/2
        c2 = self.layer_2(c1)  # P2/4
        c3 = self.layer_3(c2)  # P3/8
        c4 = self.layer_4(c3)  # P4/16
        c5 = self.layer_5(c4)  # P5/32
        c5 = self.avgpool(c5)
        c5 = torch.flatten(c5, 1)
        c5 = self.fc(c5)

        return c5

# ---------------------------- Model Functions ----------------------------
## build CSPDarkNet
def cspdarknet_n(img_dim=3, num_classes=1000) -> CSPDarkNet:
    return CSPDarkNet(img_dim=img_dim,
                      width=0.25,
                      depth=0.34,
                      act_type='silu',
                      norm_type='BN',
                      depthwise=False,
                      num_classes=num_classes)

def cspdarknet_s(img_dim=3, num_classes=1000) -> CSPDarkNet:
    return CSPDarkNet(img_dim=img_dim,
                      width=0.50,
                      depth=0.34,
                      act_type='silu',
                      norm_type='BN',
                      depthwise=False,
                      num_classes=num_classes)

def cspdarknet_m(img_dim=3, num_classes=1000) -> CSPDarkNet:
    return CSPDarkNet(img_dim=img_dim,
                      width=0.75,
                      depth=0.67,
                      act_type='silu',
                      norm_type='BN',
                      depthwise=False,
                      num_classes=num_classes)

def cspdarknet_l(img_dim=3, num_classes=1000) -> CSPDarkNet:
    return CSPDarkNet(img_dim=img_dim,
                      width=1.0,
                      depth=1.0,
                      act_type='silu',
                      norm_type='BN',
                      depthwise=False,
                      num_classes=num_classes)

def cspdarknet_x(img_dim=3, num_classes=1000) -> CSPDarkNet:
    return CSPDarkNet(img_dim=img_dim,
                      width=1.25,
                      depth=1.34,
                      act_type='silu',
                      norm_type='BN',
                      depthwise=False,
                      num_classes=num_classes)
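
# A small convenience mapping from model name to builder (an illustrative
# addition, not part of the original interface); handy for config-driven code,
# e.g. cspdarknet_builders['cspdarknet_s']().
cspdarknet_builders = {
    'cspdarknet_n': cspdarknet_n,
    'cspdarknet_s': cspdarknet_s,
    'cspdarknet_m': cspdarknet_m,
    'cspdarknet_l': cspdarknet_l,
    'cspdarknet_x': cspdarknet_x,
}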

if __name__ == '__main__':
    from thop import profile

    # build model
    model = cspdarknet_s()
    x = torch.randn(1, 3, 224, 224)

    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
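    # Output-shape sanity check (illustrative addition; assumes the default
    # 1000-class head and the 224x224 input created above).
    with torch.no_grad():
        out = model(x)
    print('Output shape : {}'.format(out.shape))  # expected: torch.Size([1, 1000])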