# basic.py

import math
import copy

import torch
import torch.nn as nn


def get_clones(module, N):
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# ----------------- MLP modules -----------------
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class FFN(nn.Module):
    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
        super().__init__()
        self.fpn_dim = round(d_model * mlp_ratio)
        self.linear1 = nn.Linear(d_model, self.fpn_dim)
        self.activation = get_activation(act_type)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(self.fpn_dim, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm(src)
        return src

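
# A small usage sketch (hypothetical helper, arbitrary dims): checks that MLP and FFN
# produce the expected output shapes for a [B, N, C] token tensor. Note that FFN relies
# on get_activation defined further down this file, so run it only after the module is
# fully loaded, e.g. `_demo_mlp_ffn_shapes()`.
def _demo_mlp_ffn_shapes():
    x = torch.randn(2, 100, 256)                  # [B, N, C] token features
    mlp = MLP(in_dim=256, hidden_dim=512, out_dim=4, num_layers=3)
    ffn = FFN(d_model=256, mlp_ratio=4.0, dropout=0.1, act_type='relu')
    assert mlp(x).shape == (2, 100, 4)            # MLP maps C=256 -> 4
    assert ffn(x).shape == x.shape                # FFN is shape-preserving
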
# ----------------- CNN modules -----------------
def get_conv2d(c1, c2, k, p, s, g, bias=False):
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
    return conv

def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

def get_norm(norm_type, dim):
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class FrozenBatchNorm2d(torch.nn.Module):
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

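
# A small sanity-check sketch (hypothetical helper, arbitrary dims and tolerance):
# assuming the affine weights and running statistics are copied from a trained
# nn.BatchNorm2d, FrozenBatchNorm2d should reproduce BatchNorm2d in eval mode.
# Run manually, e.g. `_demo_frozen_bn_matches_eval_bn()`.
def _demo_frozen_bn_matches_eval_bn():
    bn = nn.BatchNorm2d(16)
    with torch.no_grad():
        bn(torch.randn(8, 16, 4, 4))              # populate running statistics
    bn.eval()
    frozen = FrozenBatchNorm2d(16)
    frozen.load_state_dict(bn.state_dict())       # num_batches_tracked is dropped on load
    x = torch.randn(2, 16, 4, 4)
    assert torch.allclose(bn(x), frozen(x), atol=1e-5)
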
class BasicConv(nn.Module):
    def __init__(self,
                 in_dim,                       # in channels
                 out_dim,                      # out channels
                 kernel_size=1,                # kernel size
                 padding=0,                    # padding
                 stride=1,                     # stride
                 act_type: str = 'lrelu',      # activation
                 norm_type: str = 'BN',        # normalization
                 ):
        super(BasicConv, self).__init__()
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class DepthwiseConv(nn.Module):
    def __init__(self,
                 in_dim,                       # in channels
                 out_dim,                      # out channels
                 kernel_size=1,                # kernel size
                 padding=0,                    # padding
                 stride=1,                     # stride
                 act_type: str = None,         # activation
                 norm_type: str = 'BN',        # normalization
                 ):
        super(DepthwiseConv, self).__init__()
        assert in_dim == out_dim
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class PointwiseConv(nn.Module):
    def __init__(self,
                 in_dim,                       # in channels
                 out_dim,                      # out channels
                 act_type: str = 'lrelu',      # activation
                 norm_type: str = 'BN',        # normalization
                 ):
        super(PointwiseConv, self).__init__()
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

## YOLOv8's Bottleneck
class Bottleneck(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 expand_ratio = 0.5,
                 kernel_sizes = [3, 3],
                 shortcut     = True,
                 act_type     = 'silu',
                 norm_type    = 'BN',
                 depthwise    = False,):
        super(Bottleneck, self).__init__()
        inter_dim = int(out_dim * expand_ratio)
        if depthwise:
            self.cv1 = nn.Sequential(
                DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
                PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
            )
            self.cv2 = nn.Sequential(
                DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
                PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
            )
        else:
            self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
            self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
        self.shortcut = shortcut and in_dim == out_dim

    def forward(self, x):
        h = self.cv2(self.cv1(x))
        return x + h if self.shortcut else h

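
# A small usage sketch (hypothetical helper, arbitrary channel setting): compares the
# parameter count of the standard Bottleneck against its depthwise-separable variant,
# to show what the `depthwise` flag trades off. Run manually, e.g.
# `_demo_bottleneck_param_count()`.
def _demo_bottleneck_param_count():
    dense = Bottleneck(64, 64, expand_ratio=0.5, depthwise=False)
    light = Bottleneck(64, 64, expand_ratio=0.5, depthwise=True)
    n_dense = sum(p.numel() for p in dense.parameters())
    n_light = sum(p.numel() for p in light.parameters())
    print(f"standard: {n_dense} params, depthwise-separable: {n_light} params")
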
# YOLOv8's StageBlock
class RTCBlock(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 num_blocks = 1,
                 shortcut   = False,
                 act_type   = 'silu',
                 norm_type  = 'BN',
                 depthwise  = False,):
        super(RTCBlock, self).__init__()
        self.inter_dim = out_dim // 2
        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.m = nn.Sequential(*(
            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
            for _ in range(num_blocks)))
        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

    def forward(self, x):
        # Input proj
        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
        out = [x1, x2]
        # Bottleneck
        out.extend(m(out[-1]) for m in self.m)
        # Output proj
        out = self.output_proj(torch.cat(out, dim=1))
        return out

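
# A small shape walk-through (hypothetical helper, arbitrary config): the input
# projection is split into two halves of `out_dim // 2` channels, each Bottleneck
# output is appended to the list, and the concatenation of (2 + num_blocks) branches
# is projected back to `out_dim`. Run manually, e.g. `_demo_rtcblock_shapes()`.
def _demo_rtcblock_shapes():
    block = RTCBlock(in_dim=64, out_dim=128, num_blocks=2)
    x = torch.randn(1, 64, 32, 32)
    y = block(x)
    assert y.shape == (1, 128, 32, 32)            # spatial size kept, channels -> out_dim
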
# ----------------- Transformer modules -----------------
## Transformer layer
class TransformerLayer(nn.Module):
    def __init__(self,
                 d_model:   int   = 256,
                 num_heads: int   = 8,
                 mlp_ratio: float = 4.0,
                 dropout:   float = 0.1,
                 act_type:  str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.act_type = act_type
        # ----------- Network parameters -----------
        # Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos_embed=None):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src
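
# A minimal smoke test (hypothetical helper, arbitrary dims): the encoder layer above
# takes [B, N, C] tokens plus an optional positional embedding of the same shape and
# is shape-preserving. Run manually, e.g. `_demo_transformer_layer()`.
def _demo_transformer_layer():
    layer = TransformerLayer(d_model=256, num_heads=8, mlp_ratio=4.0, dropout=0.1)
    src = torch.randn(2, 100, 256)                # [B, N, C]
    pos = torch.randn(2, 100, 256)                # positional embedding, same shape as src
    out = layer(src, pos)
    assert out.shape == src.shape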