# gelan.py (9.4 KB)

  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import BasicConv, RepGElanLayer, ADown
  5. except:
  6. from modules import BasicConv, RepGElanLayer, ADown
# ---------------------------- GELAN Backbone ----------------------------
  8. class GElanCBackbone(nn.Module):
  9. def __init__(self, img_dim=3, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
  10. super(GElanCBackbone, self).__init__()
  11. # ------------------ Basic setting ------------------
  12. self.feat_dims = {
  13. "c1": [64],
  14. "c2": [128, [128, 64], 256],
  15. "c3": [256, [256, 128], 512],
  16. "c4": [512, [512, 256], 512],
  17. "c5": [512, [512, 256], 512],
  18. }
  19. # ------------------ Network setting ------------------
  20. ## P1/2
  21. self.layer_1 = BasicConv(img_dim, self.feat_dims["c1"][0],
  22. kernel_size=3, padding=1, stride=2,
  23. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  24. # P2/4
  25. self.layer_2 = nn.Sequential(
  26. BasicConv(self.feat_dims["c1"][0], self.feat_dims["c2"][0],
  27. kernel_size=3, padding=1, stride=2,
  28. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  29. RepGElanLayer(in_dim = self.feat_dims["c2"][0],
  30. inter_dims = self.feat_dims["c2"][1],
  31. out_dim = self.feat_dims["c2"][2],
  32. num_blocks = 1,
  33. shortcut = True,
  34. act_type = act_type,
  35. norm_type = norm_type,
  36. depthwise = depthwise)
  37. )
  38. # P3/8
  39. self.layer_3 = nn.Sequential(
  40. ADown(self.feat_dims["c2"][2], self.feat_dims["c3"][0],
  41. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  42. RepGElanLayer(in_dim = self.feat_dims["c3"][0],
  43. inter_dims = self.feat_dims["c3"][1],
  44. out_dim = self.feat_dims["c3"][2],
  45. num_blocks = 1,
  46. shortcut = True,
  47. act_type = act_type,
  48. norm_type = norm_type,
  49. depthwise = depthwise)
  50. )
  51. # P4/16
  52. self.layer_4 = nn.Sequential(
  53. ADown(self.feat_dims["c3"][2], self.feat_dims["c4"][0],
  54. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  55. RepGElanLayer(in_dim = self.feat_dims["c4"][0],
  56. inter_dims = self.feat_dims["c4"][1],
  57. out_dim = self.feat_dims["c4"][2],
  58. num_blocks = 1,
  59. shortcut = True,
  60. act_type = act_type,
  61. norm_type = norm_type,
  62. depthwise = depthwise)
  63. )
  64. # P5/32
  65. self.layer_5 = nn.Sequential(
  66. ADown(self.feat_dims["c4"][2], self.feat_dims["c5"][0],
  67. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  68. RepGElanLayer(in_dim = self.feat_dims["c5"][0],
  69. inter_dims = self.feat_dims["c5"][1],
  70. out_dim = self.feat_dims["c5"][2],
  71. num_blocks = 1,
  72. shortcut = True,
  73. act_type = act_type,
  74. norm_type = norm_type,
  75. depthwise = depthwise)
  76. )
  77. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  78. self.fc = nn.Linear(self.feat_dims["c5"][2], num_classes)
  79. # Initialize all layers
  80. self.init_weights()
  81. def init_weights(self):
  82. """Initialize the parameters."""
  83. for m in self.modules():
  84. if isinstance(m, torch.nn.Conv2d):
  85. # In order to be consistent with the source code,
  86. # reset the Conv2d initialization parameters
  87. m.reset_parameters()
  88. def forward(self, x):
  89. c1 = self.layer_1(x)
  90. c2 = self.layer_2(c1)
  91. c3 = self.layer_3(c2)
  92. c4 = self.layer_4(c3)
  93. c5 = self.layer_5(c4)
  94. c5 = self.avgpool(c5)
  95. c5 = torch.flatten(c5, 1)
  96. c5 = self.fc(c5)
  97. return c5
  98. class GElanSBackbone(nn.Module):
  99. def __init__(self, img_dim=3, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
  100. super(GElanSBackbone, self).__init__()
  101. # ------------------ Basic setting ------------------
  102. self.feat_dims = {
  103. "c1": [32],
  104. "c2": [64, [64, 32], 64],
  105. "c3": [64, [64, 32], 128],
  106. "c4": [128, [128, 64], 256],
  107. "c5": [256, [256, 128], 256],
  108. }
  109. # ------------------ Network setting ------------------
  110. ## P1/2
  111. self.layer_1 = BasicConv(img_dim, self.feat_dims["c1"][0],
  112. kernel_size=3, padding=1, stride=2,
  113. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  114. # P2/4
  115. self.layer_2 = nn.Sequential(
  116. BasicConv(self.feat_dims["c1"][0], self.feat_dims["c2"][0],
  117. kernel_size=3, padding=1, stride=2,
  118. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  119. RepGElanLayer(in_dim = self.feat_dims["c2"][0],
  120. inter_dims = self.feat_dims["c2"][1],
  121. out_dim = self.feat_dims["c2"][2],
  122. num_blocks = 3,
  123. shortcut = True,
  124. act_type = act_type,
  125. norm_type = norm_type,
  126. depthwise = depthwise)
  127. )
  128. # P3/8
  129. self.layer_3 = nn.Sequential(
  130. ADown(self.feat_dims["c2"][2], self.feat_dims["c3"][0],
  131. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  132. RepGElanLayer(in_dim = self.feat_dims["c3"][0],
  133. inter_dims = self.feat_dims["c3"][1],
  134. out_dim = self.feat_dims["c3"][2],
  135. num_blocks = 3,
  136. shortcut = True,
  137. act_type = act_type,
  138. norm_type = norm_type,
  139. depthwise = depthwise)
  140. )
  141. # P4/16
  142. self.layer_4 = nn.Sequential(
  143. ADown(self.feat_dims["c3"][2], self.feat_dims["c4"][0],
  144. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  145. RepGElanLayer(in_dim = self.feat_dims["c4"][0],
  146. inter_dims = self.feat_dims["c4"][1],
  147. out_dim = self.feat_dims["c4"][2],
  148. num_blocks = 3,
  149. shortcut = True,
  150. act_type = act_type,
  151. norm_type = norm_type,
  152. depthwise = depthwise)
  153. )
  154. # P5/32
  155. self.layer_5 = nn.Sequential(
  156. ADown(self.feat_dims["c4"][2], self.feat_dims["c5"][0],
  157. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  158. RepGElanLayer(in_dim = self.feat_dims["c5"][0],
  159. inter_dims = self.feat_dims["c5"][1],
  160. out_dim = self.feat_dims["c5"][2],
  161. num_blocks = 3,
  162. shortcut = True,
  163. act_type = act_type,
  164. norm_type = norm_type,
  165. depthwise = depthwise)
  166. )
  167. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  168. self.fc = nn.Linear(self.feat_dims["c5"][2], num_classes)
  169. # Initialize all layers
  170. self.init_weights()
  171. def init_weights(self):
  172. """Initialize the parameters."""
  173. for m in self.modules():
  174. if isinstance(m, torch.nn.Conv2d):
  175. # In order to be consistent with the source code,
  176. # reset the Conv2d initialization parameters
  177. m.reset_parameters()
  178. def forward(self, x):
  179. c1 = self.layer_1(x)
  180. c2 = self.layer_2(c1)
  181. c3 = self.layer_3(c2)
  182. c4 = self.layer_4(c3)
  183. c5 = self.layer_5(c4)
  184. c5 = self.avgpool(c5)
  185. c5 = torch.flatten(c5, 1)
  186. c5 = self.fc(c5)
  187. return c5
# ---------------------------- Functions ----------------------------
  189. def gelan_c(img_dim=3, num_classes=1000):
  190. return GElanCBackbone(img_dim,
  191. num_classes=num_classes,
  192. act_type='silu',
  193. norm_type='BN',
  194. depthwise=False)
  195. def gelan_s(img_dim=3, num_classes=1000):
  196. return GElanSBackbone(img_dim,
  197. num_classes=num_classes,
  198. act_type='silu',
  199. norm_type='BN',
  200. depthwise=False)
  201. if __name__ == '__main__':
  202. import torch
  203. from thop import profile
  204. # build model
  205. model = gelan_c()
  206. x = torch.randn(1, 3, 224, 224)
  207. print('==============================')
  208. flops, params = profile(model, inputs=(x, ), verbose=False)
  209. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  210. print('Params : {:.2f} M'.format(params / 1e6))