# basic.py
import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_


def get_clones(module, N):
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

# ----------------- MLP modules -----------------
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class FFN(nn.Module):
    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
        super().__init__()
        self.fpn_dim = round(d_model * mlp_ratio)
        self.linear1 = nn.Linear(d_model, self.fpn_dim)
        self.activation = get_activation(act_type)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(self.fpn_dim, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm(src)
        return src
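
# Illustrative usage sketch (added for clarity, not part of the original module
# definitions): a tiny shape check for MLP and FFN. The batch/sequence sizes and the
# 3-layer regression-head configuration below are assumptions for demonstration only.
def _demo_mlp_ffn():
    x = torch.randn(2, 100, 256)           # [B, N, C]
    mlp = MLP(256, 512, 4, num_layers=3)   # e.g. a small 3-layer regression head
    ffn = FFN(d_model=256, mlp_ratio=4.0)
    return mlp(x).shape, ffn(x).shape      # -> torch.Size([2, 100, 4]), torch.Size([2, 100, 256])
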
# ----------------- Basic CNN Ops -----------------
def get_conv2d(c1, c2, k, p, s, g, bias=False):
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
    return conv


def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError


def get_norm(norm_type, dim):
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d where the batch statistics and affine parameters are fixed."""
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
class BasicConv(nn.Module):
    def __init__(self,
                 in_dim,                   # in channels
                 out_dim,                  # out channels
                 kernel_size=1,            # kernel size
                 padding=0,                # padding
                 stride=1,                 # stride
                 act_type: str = 'lrelu',  # activation
                 norm_type: str = 'BN',    # normalization
                 ):
        super(BasicConv, self).__init__()
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
class DepthwiseConv(nn.Module):
    def __init__(self,
                 in_dim,                 # in channels
                 out_dim,                # out channels
                 kernel_size=1,          # kernel size
                 padding=0,              # padding
                 stride=1,               # stride
                 act_type: str = None,   # activation
                 norm_type: str = 'BN',  # normalization
                 ):
        super(DepthwiseConv, self).__init__()
        assert in_dim == out_dim
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
class PointwiseConv(nn.Module):
    def __init__(self,
                 in_dim,                   # in channels
                 out_dim,                  # out channels
                 act_type: str = 'lrelu',  # activation
                 norm_type: str = 'BN',    # normalization
                 ):
        super(PointwiseConv, self).__init__()
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
# ----------------- CNN Modules -----------------
class Bottleneck(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 expand_ratio=0.5,
                 kernel_sizes=[3, 3],
                 shortcut=True,
                 act_type='silu',
                 norm_type='BN',
                 depthwise=False,):
        super(Bottleneck, self).__init__()
        inter_dim = int(out_dim * expand_ratio)
        if depthwise:
            self.cv1 = nn.Sequential(
                DepthwiseConv(in_dim, in_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type),
                PointwiseConv(in_dim, inter_dim, act_type=act_type, norm_type=norm_type),
            )
            self.cv2 = nn.Sequential(
                DepthwiseConv(inter_dim, inter_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type),
                PointwiseConv(inter_dim, out_dim, act_type=act_type, norm_type=norm_type),
            )
        else:
            self.cv1 = BasicConv(in_dim, inter_dim, kernel_size=kernel_sizes[0], padding=kernel_sizes[0]//2, act_type=act_type, norm_type=norm_type)
            self.cv2 = BasicConv(inter_dim, out_dim, kernel_size=kernel_sizes[1], padding=kernel_sizes[1]//2, act_type=act_type, norm_type=norm_type)
        self.shortcut = shortcut and in_dim == out_dim

    def forward(self, x):
        h = self.cv2(self.cv1(x))
        return x + h if self.shortcut else h
class RTCBlock(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 num_blocks=1,
                 shortcut=False,
                 act_type='silu',
                 norm_type='BN',
                 depthwise=False,):
        super(RTCBlock, self).__init__()
        self.inter_dim = out_dim // 2
        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        self.m = nn.Sequential(*(
            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
            for _ in range(num_blocks)))
        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)

    def forward(self, x):
        # Input proj
        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
        out = [x1, x2]
        # Bottleneck
        out.extend(m(out[-1]) for m in self.m)
        # Output proj
        out = self.output_proj(torch.cat(out, dim=1))
        return out
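
# Illustrative usage sketch (added for clarity): feeding a dummy feature map through
# RTCBlock. The 64-channel, 32x32 input and the block configuration are assumptions
# for demonstration only.
def _demo_rtc_block():
    x = torch.randn(1, 64, 32, 32)
    block = RTCBlock(in_dim=64, out_dim=128, num_blocks=2, shortcut=True)
    return block(x).shape  # -> torch.Size([1, 128, 32, 32])
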
# ----------------- Basic Transformer Ops -----------------
def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = (
            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
        )
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    output = (
        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()
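
# Illustrative usage sketch (added for clarity): the tensor shapes the PyTorch fallback
# above expects, using two feature levels. All sizes here are assumptions for
# demonstration only.
def _demo_ms_deformable_attn_pytorch():
    bs, num_heads, head_dim, num_queries, num_points = 2, 8, 32, 100, 4
    spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
    num_levels, num_value = len(spatial_shapes), int(spatial_shapes.prod(1).sum())
    value = torch.randn(bs, num_value, num_heads, head_dim)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels * num_points)
    attention_weights = attention_weights.softmax(-1).view(bs, num_queries, num_heads, num_levels, num_points)
    out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
    return out.shape  # -> torch.Size([2, 100, 256]) since C = num_heads * head_dim
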
class MSDeformableAttention(nn.Module):
    def __init__(self,
                 embed_dim=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4):
        """
        Multi-Scale Deformable Attention Module
        """
        super(MSDeformableAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.total_points = num_heads * num_levels * num_points
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
        self.attention_weights = nn.Linear(embed_dim, self.total_points)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        try:
            # use the CUDA op if it is available
            from deformable_detr_ops import ms_deformable_attn
            self.ms_deformable_attn_core = ms_deformable_attn
        except ImportError:
            # fall back to the pure PyTorch implementation
            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch

        self._reset_parameters()

    def _reset_parameters(self):
        """
        Default initialization for Parameters of Module.
        """
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
            2.0 * math.pi / self.num_heads
        )
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.num_heads, 1, 1, 2)
            .repeat(1, self.num_levels, self.num_points, 1)
        )
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)
    def forward(self,
                query,
                reference_points,
                value,
                value_spatial_shapes,
                value_mask=None):
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
        Returns:
            output (Tensor): [bs, query_length, C]
        """
        bs, num_query = query.shape[:2]
        num_value = value.shape[1]
        assert int(value_spatial_shapes.prod(1).sum()) == num_value

        # Value projection
        value = self.value_proj(value)
        # fill "0" for the padding part
        if value_mask is not None:
            value_mask = value_mask.to(value.dtype).unsqueeze(-1)
            value *= value_mask
        # [bs, all_hw, C] -> [bs, all_hw, num_head, head_dim]
        value = value.reshape([bs, num_value, self.num_heads, -1])

        # [bs, num_query, num_head, num_level, num_point, 2]
        sampling_offsets = self.sampling_offsets(query).reshape(
            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
        # [bs, num_query, num_head, num_level*num_point]
        attention_weights = self.attention_weights(query).reshape(
            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
        attention_weights = attention_weights.softmax(-1)
        # [bs, num_query, num_head, num_level, num_point]
        attention_weights = attention_weights.reshape(
            [bs, num_query, self.num_heads, self.num_levels, self.num_points])

        # [bs, num_query, num_heads, num_levels, num_points, 2]
        if reference_points.shape[-1] == 2:
            # reference_points: [bs, num_query, num_level, 2] -> [bs, num_query, 1, num_level, 1, 2]
            # sampling_offsets: [bs, num_query, num_head, num_level, num_point, 2]
            # offset_normalizer: [num_level, 2] -> [1, 1, 1, num_level, 1, 2]
            # sampling_locations = reference_points + normalized sampling_offsets
            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
                [1, 1, 1, self.num_levels, 1, 2])
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer
            )
        elif reference_points.shape[-1] == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets
                / self.num_points
                * reference_points[:, :, None, :, None, 2:]
                * 0.5)
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but got {} instead.".format(
                    reference_points.shape[-1]))

        # Multi-scale deformable attention
        output = self.ms_deformable_attn_core(
            value, value_spatial_shapes, sampling_locations, attention_weights)

        # Output projection
        output = self.output_proj(output)

        return output
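
# Illustrative usage sketch (added for clarity): a smoke test for MSDeformableAttention
# with two feature levels and 2-d normalized reference points. All sizes are assumptions
# for demonstration only.
def _demo_msdeformable_attention():
    attn = MSDeformableAttention(embed_dim=256, num_heads=8, num_levels=2, num_points=4)
    spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
    value = torch.randn(2, int(spatial_shapes.prod(1).sum()), 256)  # [bs, sum(H*W), C]
    query = torch.randn(2, 100, 256)                                # [bs, num_queries, C]
    ref_points = torch.rand(2, 100, 2, 2)                           # [bs, num_queries, num_levels, 2]
    return attn(query, ref_points, value, spatial_shapes).shape     # -> torch.Size([2, 100, 256])
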
# ----------------- Transformer modules -----------------
## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 256,
                 num_heads: int = 8,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        # ----------- Network parameters -----------
        # Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src
## Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model: int = 256,
                 num_heads: int = 8,
                 num_layers: int = 1,
                 mlp_ratio: float = 4.0,
                 pe_temperature: float = 10000.,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == int(w) * int(h):
            return self.pos_embed
        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, C]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, C]
        # shape: [1, N, C]
        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
        self.pos_embed = pos_embed

        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        for encoder in self.encoder_layers:
            channels, fmp_h, fmp_w = src.shape[1:]
            # [B, C, H, W] -> [B, N, C], N = HxW
            src_flatten = src.flatten(2).permute(0, 2, 1)
            pos_embed = self.build_2d_sincos_position_embedding(
                fmp_w, fmp_h, channels, self.pe_temperature)
            # keep the (cached) pos_embed on the same device as the input
            memory = encoder(src_flatten, pos_embed=pos_embed.to(src_flatten.device))
            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
            src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
        return src
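
# Illustrative usage sketch (added for clarity): running the encoder over a dummy
# [B, C, H, W] feature map; the 2x256x20x20 size is an assumption for demonstration only.
def _demo_transformer_encoder():
    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1)
    feat = torch.randn(2, 256, 20, 20)
    return encoder(feat).shape  # -> torch.Size([2, 256, 20, 20])
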
## Transformer Decoder layer
class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 256,
                 num_heads: int = 8,
                 num_levels: int = 3,
                 num_points: int = 4,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        # ---------------- Network parameters ----------------
        ## Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        ## CrossAttention
        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        ## FFN
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                reference_points,
                memory,
                memory_spatial_shapes,
                attn_mask=None,
                memory_mask=None,
                query_pos_embed=None):
        # ---------------- MHSA for Object Query -----------------
        q = k = self.with_pos_embed(tgt, query_pos_embed)
        if attn_mask is not None:
            # convert the boolean mask into an additive float mask (True -> 0, False -> -inf)
            attn_mask = torch.where(
                attn_mask.to(torch.bool),
                torch.zeros_like(attn_mask, dtype=tgt.dtype),
                torch.full_like(attn_mask, float("-inf"), dtype=tgt.dtype))
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ---------------- CMHA for Object Query and Image feature -----------------
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
                               reference_points,
                               memory,
                               memory_spatial_shapes,
                               memory_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # ---------------- FeedForward Network -----------------
        tgt = self.ffn(tgt)
        return tgt
## Transformer Decoder
class DeformableTransformerDecoder(nn.Module):
    def __init__(self,
                 d_model: int = 256,
                 num_heads: int = 8,
                 num_layers: int = 1,
                 num_levels: int = 3,
                 num_points: int = 4,
                 mlp_ratio: float = 4.0,
                 pe_temperature: float = 10000.,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 return_intermediate: bool = False,
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        self.return_intermediate = return_intermediate
        # ----------- Network parameters -----------
        self.decoder_layers = get_clones(
            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == int(w) * int(h):
            return self.pos_embed
        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, C]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, C]
        # shape: [1, N, C]
        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
        self.pos_embed = pos_embed

        return pos_embed

    def forward(self,
                tgt,
                ref_points_unact,
                memory,
                memory_spatial_shapes,
                bbox_head,
                score_head,
                query_pos_head,
                attn_mask=None,
                memory_mask=None):
        output = tgt
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points_detach = F.sigmoid(ref_points_unact)
        for i, layer in enumerate(self.decoder_layers):
            ref_points_input = ref_points_detach.unsqueeze(2)
            query_pos_embed = query_pos_head(ref_points_detach)

            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, attn_mask,
                           memory_mask, query_pos_embed)

            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))

            dec_out_logits.append(score_head[i](output))
            if i == 0:
                dec_out_bboxes.append(inter_ref_bbox)
            else:
                dec_out_bboxes.append(
                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))

            ref_points = inter_ref_bbox
            ref_points_detach = inter_ref_bbox.detach()

        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
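

if __name__ == "__main__":
    # Illustrative smoke test (added for clarity): wiring the decoder with hypothetical
    # prediction heads in the spirit of DETR-style detectors. The head structure
    # (bbox_head, score_head, query_pos_head) and all sizes below are assumptions for
    # demonstration only, not the configuration of any particular model.
    d_model, num_queries, num_classes = 256, 100, 80
    num_layers, num_levels = 2, 2
    decoder = DeformableTransformerDecoder(d_model=d_model, num_layers=num_layers, num_levels=num_levels)
    bbox_head = nn.ModuleList([MLP(d_model, d_model, 4, 3) for _ in range(num_layers)])
    score_head = nn.ModuleList([nn.Linear(d_model, num_classes) for _ in range(num_layers)])
    query_pos_head = MLP(4, 2 * d_model, d_model, 2)

    spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
    memory = torch.randn(2, int(spatial_shapes.prod(1).sum()), d_model)
    tgt = torch.randn(2, num_queries, d_model)
    ref_points_unact = torch.randn(2, num_queries, 4)

    boxes, logits = decoder(tgt, ref_points_unact, memory, spatial_shapes,
                            bbox_head, score_head, query_pos_head)
    print(boxes.shape, logits.shape)  # [num_layers, 2, 100, 4], [num_layers, 2, 100, 80]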