# basic.py

import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_clones(module, N):
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# ----------------- MLP modules -----------------
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class FFN(nn.Module):
    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
        super().__init__()
        self.ffn_dim = round(d_model * mlp_ratio)
        self.linear1 = nn.Linear(d_model, self.ffn_dim)
        self.activation = get_activation(act_type)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(self.ffn_dim, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm(src)
        return src

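# Usage sketch (illustrative only; the sizes below are assumptions, not values fixed by this file):
#   mlp = MLP(in_dim=256, hidden_dim=256, out_dim=4, num_layers=3)   # e.g. a small regression head
#   y = mlp(torch.randn(2, 100, 256))                                # -> [2, 100, 4]
#   ffn = FFN(d_model=256, mlp_ratio=4.0)                            # expand/contract MLP + residual + post-LayerNorm
#   z = ffn(torch.randn(2, 100, 256))                                # -> [2, 100, 256]
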
# ----------------- CNN modules -----------------
def get_conv2d(c1, c2, k, p, s, g, bias=False):
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
    return conv


def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError


def get_norm(norm_type, dim):
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class FrozenBatchNorm2d(torch.nn.Module):
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

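# Sanity-check sketch (illustrative, not part of the module): the affine form above,
# scale = w / sqrt(running_var + eps) and bias = b - running_mean * scale, reproduces
# BatchNorm2d in eval mode with matching statistics.
#   bn = nn.BatchNorm2d(8).eval()
#   fbn = FrozenBatchNorm2d(8)
#   fbn.load_state_dict(bn.state_dict())          # num_batches_tracked is dropped on load
#   x = torch.randn(2, 8, 4, 4)
#   torch.allclose(bn(x), fbn(x), atol=1e-6)      # expected: True
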
class BasicConv(nn.Module):
    def __init__(self,
                 in_dim,                     # in channels
                 out_dim,                    # out channels
                 kernel_size=1,              # kernel size
                 padding=0,                  # padding
                 stride=1,                   # stride
                 act_type :str = 'lrelu',    # activation
                 norm_type :str = 'BN',      # normalization
                 ):
        super(BasicConv, self).__init__()
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class DepthwiseConv(nn.Module):
    def __init__(self,
                 in_dim,                     # in channels
                 out_dim,                    # out channels
                 kernel_size=1,              # kernel size
                 padding=0,                  # padding
                 stride=1,                   # stride
                 act_type :str = None,       # activation
                 norm_type :str = 'BN',      # normalization
                 ):
        super(DepthwiseConv, self).__init__()
        assert in_dim == out_dim, "Depthwise conv requires in_dim == out_dim."
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class PointwiseConv(nn.Module):
    def __init__(self,
                 in_dim,                     # in channels
                 out_dim,                    # out channels
                 act_type :str = 'lrelu',    # activation
                 norm_type :str = 'BN',      # normalization
                 ):
        super(PointwiseConv, self).__init__()
        # A 1x1 conv may change the channel count, so no in_dim == out_dim constraint here.
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

## YOLOv8's Bottleneck
class Bottleneck(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 expand_ratio = 0.5,
                 kernel_sizes = [3, 3],
                 shortcut     = True,
                 act_type     = 'silu',
                 norm_type    = 'BN',
                 depthwise    = False,):
        super(Bottleneck, self).__init__()
        inter_dim = int(out_dim * expand_ratio)
        if depthwise:
            self.cv1 = nn.Sequential(
                DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
                PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
            )
            self.cv2 = nn.Sequential(
                DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
                PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
            )
        else:
            self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
            self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
        self.shortcut = shortcut and in_dim == out_dim

    def forward(self, x):
        h = self.cv2(self.cv1(x))
        return x + h if self.shortcut else h

# YOLOv8's StageBlock
class RTCBlock(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 num_blocks = 1,
                 shortcut   = False,
                 act_type   = 'silu',
                 norm_type  = 'BN',
                 depthwise  = False,):
        super(RTCBlock, self).__init__()
        self.inter_dim = out_dim // 2
        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.m = nn.Sequential(*(
            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
            for _ in range(num_blocks)))
        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

    def forward(self, x):
        # Input proj
        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
        out = [x1, x2]
        # Bottleneck
        out.extend(m(out[-1]) for m in self.m)
        # Output proj
        out = self.output_proj(torch.cat(out, dim=1))
        return out

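# Channel-flow sketch (illustrative; the 256-channel input below is an assumption):
#   block = RTCBlock(in_dim=256, out_dim=256, num_blocks=3)
#   y = block(torch.randn(2, 256, 40, 40))   # -> [2, 256, 40, 40]
# Internally: input_proj -> 256 ch, chunked into 2 x 128; each of the 3 bottlenecks adds
# another 128-ch branch, so the concat holds (2 + 3) * 128 = 640 ch before output_proj.
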
# ----------------- Transformer modules -----------------
## Basic ops of Deformable Attn
def deformable_attention_core_func(value, value_spatial_shapes,
                                   value_level_start_index, sampling_locations,
                                   attention_weights):
    """
    Args:
        value (Tensor): [bs, value_length, n_head, c]
        value_spatial_shapes (Tensor|List): [n_levels, 2]
        value_level_start_index (Tensor|List): [n_levels]
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, Length_{query}, C]
    """
    bs, _, n_head, c = value.shape
    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape

    split_shape = [int(h) * int(w) for h, w in value_spatial_shapes]
    value_list = value.split(split_shape, dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        h, w = int(h), int(w)
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[level].flatten(2).permute(0, 2, 1).reshape(bs * n_head, c, h, w)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].permute(0, 2, 1, 3, 4).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(
        bs * n_head, 1, Len_q, n_levels * n_points)

    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).reshape(bs, n_head * c, Len_q)

    return output.permute(0, 2, 1)

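# Shape walkthrough (illustrative; all sizes below are assumptions):
#   bs, n_head, c, n_levels, n_points, Len_q = 2, 8, 32, 3, 4, 100
#   spatial_shapes = [(40, 40), (20, 20), (10, 10)]                # value_length = 2100
#   value   = torch.randn(2, 2100, 8, 32)
#   loc     = torch.rand(2, 100, 8, 3, 4, 2)                       # normalized sampling locations
#   weights = torch.rand(2, 100, 8, 3, 4).softmax(-1)
#   out = deformable_attention_core_func(value, spatial_shapes, None, loc, weights)
#   out.shape                                                      # -> [2, 100, 256]
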
class MSDeformableAttention(nn.Module):
    def __init__(self,
                 embed_dim=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 lr_mult=0.1):
        """
        Multi-Scale Deformable Attention Module
        """
        super(MSDeformableAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.total_points = num_heads * num_levels * num_points
        # Kept for API compatibility; a per-parameter learning rate is applied
        # through the optimizer's param groups rather than inside the layer.
        self.lr_mult = lr_mult

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
        self.attention_weights = nn.Linear(embed_dim, self.total_points)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        try:
            # use the fused CUDA op if it is available
            from deformable_detr_ops import ms_deformable_attn
            self.ms_deformable_attn_core = ms_deformable_attn
        except ImportError:
            # fall back to the pure PyTorch implementation above
            self.ms_deformable_attn_core = deformable_attention_core_func

        self._reset_parameters()

    def _reset_parameters(self):
        # sampling_offsets
        nn.init.constant_(self.sampling_offsets.weight, 0.)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
        grid_init = grid_init.reshape(self.num_heads, 1, 1, 2).repeat(1, self.num_levels, self.num_points, 1)
        scaling = torch.arange(1, self.num_points + 1, dtype=torch.float32).reshape(1, 1, -1, 1)
        grid_init *= scaling
        with torch.no_grad():
            self.sampling_offsets.bias.copy_(grid_init.flatten())
        # attention_weights
        nn.init.constant_(self.attention_weights.weight, 0.)
        nn.init.constant_(self.attention_weights.bias, 0.)
        # proj
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.constant_(self.value_proj.bias, 0.)
        nn.init.xavier_uniform_(self.output_proj.weight)
        nn.init.constant_(self.output_proj.bias, 0.)

    def forward(self,
                query,
                reference_points,
                value,
                value_spatial_shapes,
                value_level_start_index,
                value_mask=None):
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
            value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements

        Returns:
            output (Tensor): [bs, Length_{query}, C]
        """
        bs, Len_q = query.shape[:2]
        Len_v = value.shape[1]
        assert int(value_spatial_shapes.prod(1).sum()) == Len_v

        value = self.value_proj(value)
        if value_mask is not None:
            value_mask = value_mask.to(value.dtype).unsqueeze(-1)
            value *= value_mask
        value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)

        sampling_offsets = self.sampling_offsets(query).reshape(
            bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).reshape(
            bs, Len_q, self.num_heads, self.num_levels * self.num_points)
        attention_weights = F.softmax(attention_weights, dim=-1).reshape(
            bs, Len_q, self.num_heads, self.num_levels, self.num_points)

        if reference_points.shape[-1] == 2:
            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
                1, 1, 1, self.num_levels, 1, 2)
            sampling_locations = reference_points.reshape(
                bs, Len_q, 1, self.num_levels, 1, 2
            ) + sampling_offsets / offset_normalizer
        elif reference_points.shape[-1] == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2] + sampling_offsets /
                self.num_points * reference_points[:, :, None, :, None, 2:] *
                0.5)
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".
                format(reference_points.shape[-1]))

        output = self.ms_deformable_attn_core(
            value, value_spatial_shapes, value_level_start_index,
            sampling_locations, attention_weights)
        output = self.output_proj(output)

        return output

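# Usage sketch (illustrative; all sizes below are assumptions):
#   attn = MSDeformableAttention(embed_dim=256, num_heads=8, num_levels=2, num_points=4)
#   spatial_shapes = torch.tensor([[20, 20], [10, 10]])            # value_length = 500
#   start_index = torch.tensor([0, 400])
#   query = torch.randn(2, 300, 256)
#   value = torch.randn(2, 500, 256)
#   ref_points = torch.rand(2, 300, 2, 2)                          # [bs, Len_q, n_levels, 2] in [0, 1]
#   out = attn(query, ref_points, value, spatial_shapes, start_index)   # -> [2, 300, 256]
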
## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 mlp_ratio :float = 4.0,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.act_type = act_type
        # ----------- Network parameters -----------
        # Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src

## Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 mlp_ratio      :float = 4.0,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        # pos_embed has shape [1, N, C] with N = H*W, so the cache is keyed on N.
        if self.pos_embed is not None and \
            self.pos_embed.shape[1] == int(w) * int(h):
            return self.pos_embed

        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]

        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]

        # shape: [1, N, C]
        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w),
                               torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        self.pos_embed = pos_embed

        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        for encoder in self.encoder_layers:
            channels, fmp_h, fmp_w = src.shape[1:]
            # [B, C, H, W] -> [B, N, C], N = HxW
            src_flatten = src.flatten(2).permute(0, 2, 1)
            pos_embed = self.build_2d_sincos_position_embedding(
                fmp_w, fmp_h, channels, self.pe_temperature)
            memory = encoder(src_flatten, pos_embed=pos_embed.to(src.device))
            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
            src = memory.permute(0, 2, 1).reshape(-1, channels, fmp_h, fmp_w)
        return src

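# Usage sketch (illustrative; the feature-map size is an assumption):
#   encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1)
#   feat = torch.randn(2, 256, 20, 20)
#   out = encoder(feat)   # -> [2, 256, 20, 20]; self-attention runs over the 400 flattened positions
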
## Transformer Decoder layer
class TransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model    :int   = 256,
                 num_heads  :int   = 8,
                 num_levels :int   = 3,
                 num_points :int   = 4,
                 mlp_ratio  :float = 4.0,
                 dropout    :float = 0.1,
                 act_type   :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        # ---------------- Network parameters ----------------
        ## Multi-head Self-Attn (batch_first to match the [B, N, C] layout used throughout)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        ## Cross-Attention (multi-scale deformable attention)
        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points, 1.0)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        ## FFN
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                reference_points,
                memory,
                memory_spatial_shapes,
                memory_level_start_index,
                attn_mask=None,
                memory_mask=None,
                query_pos_embed=None):
        # ---------------- MHSA for Object Query -----------------
        q = k = self.with_pos_embed(tgt, query_pos_embed)
        if attn_mask is not None:
            # Convert a boolean mask (True = keep) into an additive float mask.
            attn_mask = torch.where(
                attn_mask.to(torch.bool),
                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=tgt.device),
                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=tgt.device))
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ---------------- Cross-Attention between Object Query and Image feature -----------------
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
                               reference_points,
                               memory,
                               memory_spatial_shapes,
                               memory_level_start_index,
                               memory_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # ---------------- FeedForward Network -----------------
        tgt = self.ffn(tgt)
        return tgt

## Transformer Decoder
class TransformerDecoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 mlp_ratio      :float = 4.0,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.decoder_layers = None

    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        # pos_embed has shape [1, N, C] with N = H*W, so the cache is keyed on N.
        if self.pos_embed is not None and \
            self.pos_embed.shape[1] == int(w) * int(h):
            return self.pos_embed

        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]

        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]

        # shape: [1, N, C]
        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w),
                               torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        self.pos_embed = pos_embed

        return pos_embed