basic.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. import torch
  2. import torch.nn as nn
  3. # ----------------- MLP modules -----------------
  4. class MLP(nn.Module):
  5. def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
  6. super().__init__()
  7. self.num_layers = num_layers
  8. h = [hidden_dim] * (num_layers - 1)
  9. self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
  10. def forward(self, x):
  11. for i, layer in enumerate(self.layers):
  12. x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  13. return x
  14. class FFN(nn.Module):
  15. def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
  16. super().__init__()
  17. self.fpn_dim = round(d_model * mlp_ratio)
  18. self.linear1 = nn.Linear(d_model, self.fpn_dim)
  19. self.activation = get_activation(act_type)
  20. self.dropout2 = nn.Dropout(dropout)
  21. self.linear2 = nn.Linear(self.fpn_dim, d_model)
  22. self.dropout3 = nn.Dropout(dropout)
  23. self.norm = nn.LayerNorm(d_model)
  24. def forward(self, src):
  25. src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
  26. src = src + self.dropout3(src2)
  27. src = self.norm(src)
  28. return src
  29. # ----------------- Basic CNN Ops -----------------
  30. def get_conv2d(c1, c2, k, p, s, g, bias=False):
  31. conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
  32. return conv
  33. def get_activation(act_type=None):
  34. if act_type == 'relu':
  35. return nn.ReLU(inplace=True)
  36. elif act_type == 'lrelu':
  37. return nn.LeakyReLU(0.1, inplace=True)
  38. elif act_type == 'mish':
  39. return nn.Mish(inplace=True)
  40. elif act_type == 'silu':
  41. return nn.SiLU(inplace=True)
  42. elif act_type == 'gelu':
  43. return nn.GELU()
  44. elif act_type is None:
  45. return nn.Identity()
  46. else:
  47. raise NotImplementedError
  48. def get_norm(norm_type, dim):
  49. if norm_type == 'BN':
  50. return nn.BatchNorm2d(dim)
  51. elif norm_type == 'GN':
  52. return nn.GroupNorm(num_groups=32, num_channels=dim)
  53. elif norm_type is None:
  54. return nn.Identity()
  55. else:
  56. raise NotImplementedError
  57. def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
  58. """3x3 convolution with padding"""
  59. return nn.Conv2d(
  60. in_planes,
  61. out_planes,
  62. kernel_size=3,
  63. stride=stride,
  64. padding=dilation,
  65. groups=groups,
  66. bias=False,
  67. dilation=dilation,
  68. )
  69. def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  70. """1x1 convolution"""
  71. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  72. class FrozenBatchNorm2d(torch.nn.Module):
  73. def __init__(self, n):
  74. super(FrozenBatchNorm2d, self).__init__()
  75. self.register_buffer("weight", torch.ones(n))
  76. self.register_buffer("bias", torch.zeros(n))
  77. self.register_buffer("running_mean", torch.zeros(n))
  78. self.register_buffer("running_var", torch.ones(n))
  79. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  80. missing_keys, unexpected_keys, error_msgs):
  81. num_batches_tracked_key = prefix + 'num_batches_tracked'
  82. if num_batches_tracked_key in state_dict:
  83. del state_dict[num_batches_tracked_key]
  84. super(FrozenBatchNorm2d, self)._load_from_state_dict(
  85. state_dict, prefix, local_metadata, strict,
  86. missing_keys, unexpected_keys, error_msgs)
  87. def forward(self, x):
  88. # move reshapes to the beginning
  89. # to make it fuser-friendly
  90. w = self.weight.reshape(1, -1, 1, 1)
  91. b = self.bias.reshape(1, -1, 1, 1)
  92. rv = self.running_var.reshape(1, -1, 1, 1)
  93. rm = self.running_mean.reshape(1, -1, 1, 1)
  94. eps = 1e-5
  95. scale = w * (rv + eps).rsqrt()
  96. bias = b - rm * scale
  97. return x * scale + bias
  98. class BasicConv(nn.Module):
  99. def __init__(self,
  100. in_dim, # in channels
  101. out_dim, # out channels
  102. kernel_size=1, # kernel size
  103. padding=0, # padding
  104. stride=1, # padding
  105. act_type :str = 'lrelu', # activation
  106. norm_type :str = 'BN', # normalization
  107. ):
  108. super(BasicConv, self).__init__()
  109. add_bias = False if norm_type else True
  110. self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
  111. self.norm = get_norm(norm_type, out_dim)
  112. self.act = get_activation(act_type)
  113. def forward(self, x):
  114. return self.act(self.norm(self.conv(x)))
  115. class DepthwiseConv(nn.Module):
  116. def __init__(self,
  117. in_dim, # in channels
  118. out_dim, # out channels
  119. kernel_size=1, # kernel size
  120. padding=0, # padding
  121. stride=1, # padding
  122. act_type :str = None, # activation
  123. norm_type :str = 'BN', # normalization
  124. ):
  125. super(DepthwiseConv, self).__init__()
  126. assert in_dim == out_dim
  127. add_bias = False if norm_type else True
  128. self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
  129. self.norm = get_norm(norm_type, out_dim)
  130. self.act = get_activation(act_type)
  131. def forward(self, x):
  132. return self.act(self.norm(self.conv(x)))
  133. class PointwiseConv(nn.Module):
  134. def __init__(self,
  135. in_dim, # in channels
  136. out_dim, # out channels
  137. act_type :str = 'lrelu', # activation
  138. norm_type :str = 'BN', # normalization
  139. ):
  140. super(DepthwiseConv, self).__init__()
  141. assert in_dim == out_dim
  142. add_bias = False if norm_type else True
  143. self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
  144. self.norm = get_norm(norm_type, out_dim)
  145. self.act = get_activation(act_type)
  146. def forward(self, x):
  147. return self.act(self.norm(self.conv(x)))
  148. # ----------------- CNN Modules -----------------
  149. class Bottleneck(nn.Module):
  150. def __init__(self,
  151. in_dim,
  152. out_dim,
  153. expand_ratio = 0.5,
  154. kernel_sizes = [3, 3],
  155. shortcut = True,
  156. act_type = 'silu',
  157. norm_type = 'BN',
  158. depthwise = False,):
  159. super(Bottleneck, self).__init__()
  160. inter_dim = int(out_dim * expand_ratio)
  161. if depthwise:
  162. self.cv1 = nn.Sequential(
  163. DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
  164. PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
  165. )
  166. self.cv2 = nn.Sequential(
  167. DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
  168. PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
  169. )
  170. else:
  171. self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
  172. self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
  173. self.shortcut = shortcut and in_dim == out_dim
  174. def forward(self, x):
  175. h = self.cv2(self.cv1(x))
  176. return x + h if self.shortcut else h
  177. class RTCBlock(nn.Module):
  178. def __init__(self,
  179. in_dim,
  180. out_dim,
  181. num_blocks = 1,
  182. shortcut = False,
  183. act_type = 'silu',
  184. norm_type = 'BN',
  185. depthwise = False,):
  186. super(RTCBlock, self).__init__()
  187. self.inter_dim = out_dim // 2
  188. self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  189. self.m = nn.Sequential(*(
  190. Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
  191. for _ in range(num_blocks)))
  192. self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  193. def forward(self, x):
  194. # Input proj
  195. x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
  196. out = list([x1, x2])
  197. # Bottlenecl
  198. out.extend(m(out[-1]) for m in self.m)
  199. # Output proj
  200. out = self.output_proj(torch.cat(out, dim=1))
  201. return out
  202. class RepVggBlock(nn.Module):
  203. def __init__(self, in_dim, out_dim, act_type='relu', norm_type='BN', alpha=False):
  204. super(RepVggBlock, self).__init__()
  205. self.in_dim = in_dim
  206. self.out_dim = out_dim
  207. self.conv1 = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=norm_type)
  208. self.conv2 = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=norm_type)
  209. self.act = get_activation(act_type)
  210. if alpha:
  211. self.alpha = nn.Parameter(torch.as_tensor([1.0]).float())
  212. else:
  213. self.alpha = None
  214. def forward(self, x):
  215. if hasattr(self, 'conv'):
  216. y = self.conv(x)
  217. else:
  218. if self.alpha:
  219. y = self.conv1(x) + self.alpha * self.conv2(x)
  220. else:
  221. y = self.conv1(x) + self.conv2(x)
  222. y = self.act(y)
  223. return y
  224. def convert_to_deploy(self):
  225. if not hasattr(self, 'conv'):
  226. self.conv = nn.Conv2d(
  227. self.in_dim,
  228. self.out_dim,
  229. kernel_size=3,
  230. stride=1,
  231. padding=1,
  232. groups=1)
  233. kernel, bias = self.get_equivalent_kernel_bias()
  234. # self.conv.weight.set_value(kernel)
  235. # self.conv.bias.set_value(bias)
  236. self.conv.weight.data = kernel
  237. self.conv.bias.data = bias
  238. self.__delattr__('conv1')
  239. self.__delattr__('conv2')
  240. def get_equivalent_kernel_bias(self):
  241. kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
  242. kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
  243. if self.alpha:
  244. return kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor(
  245. kernel1x1), bias3x3 + self.alpha * bias1x1
  246. else:
  247. return kernel3x3 + self._pad_1x1_to_3x3_tensor(
  248. kernel1x1), bias3x3 + bias1x1
  249. def _pad_1x1_to_3x3_tensor(self, kernel1x1):
  250. if kernel1x1 is None:
  251. return 0
  252. else:
  253. return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  254. def _fuse_bn_tensor(self, branch):
  255. if branch is None:
  256. return 0, 0
  257. kernel = branch.conv.weight
  258. running_mean = branch.bn._mean
  259. running_var = branch.bn._variance
  260. gamma = branch.bn.weight
  261. beta = branch.bn.bias
  262. eps = branch.bn._epsilon
  263. std = (running_var + eps).sqrt()
  264. t = (gamma / std).reshape((-1, 1, 1, 1))
  265. return kernel * t, beta - running_mean * gamma / std
  266. class CSPRepLayer(nn.Module):
  267. def __init__(self,
  268. in_dim :int,
  269. out_dim :int,
  270. num_blocks :int = 3,
  271. expansion :float = 1.0,
  272. act_type :str ="silu",
  273. norm_type :str = 'BN'):
  274. super(CSPRepLayer, self).__init__()
  275. hidden_dim = int(out_dim * expansion)
  276. self.conv1 = BasicConv(
  277. in_dim, hidden_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  278. self.conv2 = BasicConv(
  279. in_dim, hidden_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  280. self.bottlenecks = nn.Sequential(*[
  281. RepVggBlock(
  282. hidden_dim, hidden_dim, act_type=act_type, norm_type=norm_type)
  283. for _ in range(num_blocks)
  284. ])
  285. if hidden_dim != out_dim:
  286. self.conv3 = BasicConv(hidden_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
  287. else:
  288. self.conv3 = nn.Identity()
  289. def forward(self, x):
  290. x_1 = self.conv1(x)
  291. x_1 = self.bottlenecks(x_1)
  292. x_2 = self.conv2(x)
  293. return self.conv3(x_1 + x_2)