# rtcdet_basic.py

import torch
import torch.nn as nn
from typing import List


# --------------------- Basic modules ---------------------
def get_conv2d(c1, c2, k, p, s, d=1, g=1, bias=False):
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
    return conv

def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

def get_norm(norm_type, dim):
    if norm_type == 'bn':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'gn':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

class BasicConv(nn.Module):
    def __init__(self,
                 in_dim,                     # in channels
                 out_dim,                    # out channels
                 kernel_size=1,              # kernel size
                 padding=0,                  # padding
                 stride=1,                   # stride
                 dilation=1,                 # dilation
                 groups=1,                   # groups
                 act_type: str = 'lrelu',    # activation
                 norm_type: str = 'bn',      # normalization
                 depthwise: bool = False
                 ):
        super(BasicConv, self).__init__()
        self.depthwise = depthwise
        use_bias = False if norm_type is not None else True
        if not depthwise:
            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=groups, bias=use_bias)
            self.norm = get_norm(norm_type, out_dim)
        else:
            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
            self.norm1 = get_norm(norm_type, in_dim)
            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1, bias=use_bias)
            self.norm2 = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        if not self.depthwise:
            return self.act(self.norm(self.conv(x)))
        else:
            # Depthwise conv
            x = self.act(self.norm1(self.conv1(x)))
            # Pointwise conv
            x = self.act(self.norm2(self.conv2(x)))
            return x

class DWConv(nn.Module):
    def __init__(self,
                 in_dim: int,                # in channels
                 out_dim: int,               # out channels
                 kernel_size: int = 1,       # kernel size
                 padding: int = 0,           # padding
                 stride: int = 1,            # stride
                 dilation: int = 1,          # dilation
                 act_type: str = 'lrelu',    # activation
                 norm_type: str = 'bn',      # normalization
                 ):
        super(DWConv, self).__init__()
        assert in_dim == out_dim
        use_bias = False if norm_type is not None else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=out_dim, bias=use_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
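
# Shape note (illustrative, with assumed dims): BasicConv(64, 128, kernel_size=3, padding=1, stride=2)
# maps a (B, 64, H, W) tensor to (B, 128, H/2, W/2) for even H and W, while DWConv requires
# in_dim == out_dim and applies a per-channel (grouped) convolution that preserves the channel count.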

# --------------------- Downsample modules ---------------------
class ADown(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 act_type: str = "silu",
                 norm_type: str = "bn",
                 depthwise: bool = False):
        super().__init__()
        inter_dim = out_dim // 2
        self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim, kernel_size=3, padding=1, stride=2,
                                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
                                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)

    def forward(self, x):
        # 2x2 avg-pool (stride 1), then split along channels
        x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
        x1, x2 = x.chunk(2, 1)
        # Downsample branch - 1: 3x3 conv with stride 2
        x1 = self.conv_layer_1(x1)
        # Downsample branch - 2: max-pool with stride 2, then 1x1 conv
        x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
        x2 = self.conv_layer_2(x2)
        return torch.cat([x1, x2], dim=1)

class MDown(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 act_type: str = 'silu',
                 norm_type: str = 'bn',
                 depthwise: bool = False,
                 ) -> None:
        super().__init__()
        inter_dim = out_dim // 2
        self.downsample_1 = nn.Sequential(
            nn.MaxPool2d((2, 2), stride=2),
            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        )
        self.downsample_2 = nn.Sequential(
            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type),
            BasicConv(inter_dim, inter_dim,
                      kernel_size=3, padding=1, stride=2,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )

    def forward(self, x):
        x1 = self.downsample_1(x)
        x2 = self.downsample_2(x)
        return torch.cat([x1, x2], dim=1)
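
# Shape note (illustrative): both ADown and MDown halve the spatial resolution; assuming out_dim
# is even, each produces out_dim output channels (two branches of out_dim // 2, concatenated).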

# --------------------- Feature processing modules ---------------------
class MBottleneck(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 expansion: float = 0.5,
                 shortcut: bool = False,
                 act_type: str = 'silu',
                 norm_type: str = 'bn',
                 depthwise: bool = False,
                 ) -> None:
        super(MBottleneck, self).__init__()
        inter_dim = int(out_dim * expansion)
        # ----------------- Network setting -----------------
        self.conv_layer = nn.Sequential(
            # 3x3 conv + norm + act
            BasicConv(in_dim, inter_dim, kernel_size=3, padding=1, stride=1,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
            # 5x5 depthwise conv (no activation)
            DWConv(inter_dim, inter_dim, kernel_size=5, padding=2, stride=1,
                   act_type=None, norm_type=norm_type),
            # 3x3 conv + norm + act
            BasicConv(inter_dim, out_dim, kernel_size=3, padding=1, stride=1,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
        )
        self.shortcut = shortcut and in_dim == out_dim

    def forward(self, x):
        h = self.conv_layer(x)
        return x + h if self.shortcut else h

class CSPLayer(nn.Module):
    # CSP Bottleneck
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 num_blocks: int = 1,
                 expansion: float = 0.5,
                 shortcut: bool = True,
                 act_type: str = 'silu',
                 norm_type: str = 'bn',
                 depthwise: bool = False,
                 ) -> None:
        super().__init__()
        inter_dim = round(out_dim * expansion)
        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=norm_type, depthwise=depthwise)
        self.module = nn.Sequential(*[MBottleneck(inter_dim,
                                                  inter_dim,
                                                  expansion=1.0,
                                                  shortcut=shortcut,
                                                  act_type=act_type,
                                                  norm_type=norm_type,
                                                  depthwise=depthwise,
                                                  ) for _ in range(num_blocks)])

    def forward(self, x):
        # Split
        x1, x2 = torch.chunk(self.input_proj(x), chunks=2, dim=1)
        # Bottleneck branch
        x2 = self.module(x2)
        # Concatenate both branches
        out = torch.cat([x1, x2], dim=1)
        return out

class ElanLayer(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 expansion: float = 0.5,
                 num_blocks: int = 1,
                 shortcut: bool = False,
                 act_type: str = 'silu',
                 norm_type: str = 'bn',
                 depthwise: bool = False,
                 ) -> None:
        super(ElanLayer, self).__init__()
        inter_dim = round(out_dim * expansion)
        self.input_proj = BasicConv(in_dim, inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.output_proj = BasicConv((2 + num_blocks) * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.module = nn.ModuleList([MBottleneck(inter_dim,
                                                 inter_dim,
                                                 expansion=1.0,
                                                 shortcut=shortcut,
                                                 act_type=act_type,
                                                 norm_type=norm_type,
                                                 depthwise=depthwise)
                                     for _ in range(num_blocks)])

    def forward(self, x):
        # Input proj
        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
        out = [x1, x2]
        # Bottlenecks
        out.extend(m(out[-1]) for m in self.module)
        # Output proj
        out = self.output_proj(torch.cat(out, dim=1))
        return out

class GElanLayer(nn.Module):
    """Modified YOLOv9's GELAN module"""
    def __init__(self,
                 in_dim: int,
                 inter_dims: List,
                 out_dim: int,
                 num_blocks: int = 1,
                 shortcut: bool = False,
                 act_type: str = 'silu',
                 norm_type: str = 'bn',
                 depthwise: bool = False,
                 ) -> None:
        super(GElanLayer, self).__init__()
        # ----------- Basic parameters -----------
        self.in_dim = in_dim
        self.inter_dims = inter_dims
        self.out_dim = out_dim
        # ----------- Network parameters -----------
        self.conv_layer_1 = BasicConv(in_dim, inter_dims[0], kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.elan_module_1 = nn.Sequential(
            CSPLayer(inter_dims[0] // 2,
                     inter_dims[1],
                     num_blocks=num_blocks,
                     shortcut=shortcut,
                     expansion=0.5,
                     act_type=act_type,
                     norm_type=norm_type,
                     depthwise=depthwise),
            BasicConv(inter_dims[1], inter_dims[1], kernel_size=3, padding=1,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        self.elan_module_2 = nn.Sequential(
            CSPLayer(inter_dims[1],
                     inter_dims[1],
                     num_blocks=num_blocks,
                     shortcut=shortcut,
                     expansion=0.5,
                     act_type=act_type,
                     norm_type=norm_type,
                     depthwise=depthwise),
            BasicConv(inter_dims[1], inter_dims[1], kernel_size=3, padding=1,
                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
        )
        self.conv_layer_2 = BasicConv(inter_dims[0] + 2 * self.inter_dims[1], out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

    def forward(self, x):
        # Input proj
        x1, x2 = torch.chunk(self.conv_layer_1(x), 2, dim=1)
        out = [x1, x2]
        # ELAN modules
        out.append(self.elan_module_1(out[-1]))
        out.append(self.elan_module_2(out[-1]))
        # Output proj
        out = self.conv_layer_2(torch.cat(out, dim=1))
        return out
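
# A minimal smoke-test sketch (illustrative only): the batch size, channel dims, and input
# resolution below are assumptions chosen to exercise each block once, not values taken from
# the original configs.
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)
    # Downsample blocks: (2, 64, 32, 32) -> (2, 128, 16, 16)
    print(ADown(64, 128)(x).shape)
    print(MDown(64, 128)(x).shape)
    # Feature processing blocks keep the spatial size
    print(CSPLayer(64, 64, num_blocks=2)(x).shape)               # (2, 64, 32, 32)
    print(ElanLayer(64, 128, num_blocks=2)(x).shape)             # (2, 128, 32, 32)
    print(GElanLayer(64, [64, 64], 128, num_blocks=1)(x).shape)  # (2, 128, 32, 32)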