transformer.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr
import torch
import torch.nn as nn

try:
    from .transformer_encoder import TransformerEncoderLayer, TransformerEncoder
    from .transformer_decoder import TransformerDecoderLayer, TransformerDecoder
except ImportError:
    from transformer_encoder import TransformerEncoderLayer, TransformerEncoder
    from transformer_decoder import TransformerDecoderLayer, TransformerDecoder

class DETRTransformer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 512,
                 num_heads: int = 8,
                 ffn_dim: int = 2048,
                 num_enc_layers: int = 6,
                 num_dec_layers: int = 6,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 pre_norm: bool = False,
                 return_intermediate_dec: bool = False):
        super().__init__()
        # ---------- Basic parameters ----------
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.pre_norm = pre_norm
        self.num_enc_layers = num_enc_layers
        self.num_dec_layers = num_dec_layers
        self.return_intermediate_dec = return_intermediate_dec
        # ---------- Model parameters ----------
        ## Encoder module
        encoder_layer = TransformerEncoderLayer(
            hidden_dim, num_heads, ffn_dim, dropout, act_type, pre_norm)
        encoder_norm = nn.LayerNorm(hidden_dim) if pre_norm else None
        self.encoder = TransformerEncoder(encoder_layer, num_enc_layers, encoder_norm)
        ## Decoder module
        decoder_layer = TransformerDecoderLayer(
            hidden_dim, num_heads, ffn_dim, dropout, act_type, pre_norm)
        decoder_norm = nn.LayerNorm(hidden_dim)
        self.decoder = TransformerDecoder(decoder_layer, num_dec_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix (parameters with more than one dim)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def get_posembed(self, embed_dim, src_mask, temperature=10000, normalize=False):
        """Build the sine-cosine position embedding from the padding mask."""
        scale = 2 * torch.pi
        num_pos_feats = embed_dim // 2
        not_mask = ~src_mask
        # [B, H, W]
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        # normalize grid coords to [0, 2*pi]
        if normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=src_mask.device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)
        pos_x = torch.div(x_embed[..., None], dim_t)
        pos_y = torch.div(y_embed[..., None], dim_t)
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
        # [B, H, W, C] -> [B, C, H, W]
        pos_embed = torch.cat((pos_y, pos_x), dim=-1).permute(0, 3, 1, 2)
        return pos_embed

    def forward(self, src, src_mask, query_embed):
        bs, c, h, w = src.shape
        # Get position embedding
        pos_embed = self.get_posembed(c, src_mask, normalize=True)
        # reshape: [B, C, H, W] -> [N, B, C], N = HW
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        src_mask = src_mask.flatten(1)
        # [Nq, C] -> [Nq, B, C]
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        # Encoder
        memory = self.encoder(src, src_key_padding_mask=src_mask, pos_embed=pos_embed)
        # Decoder
        tgt = torch.zeros_like(query_embed)
        hs = self.decoder(tgt=tgt,
                          tgt_mask=None,
                          memory=memory,
                          memory_mask=src_mask,
                          memory_pos=pos_embed,
                          query_pos=query_embed)
        # [M, Nq, B, C] -> [M, B, Nq, C]
        hs = hs.transpose(1, 2)
        # [N, B, C] -> [B, C, N] -> [B, C, H, W]
        memory = memory.permute(1, 2, 0).view(bs, c, h, w)
        return hs, memory
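

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). It assumes the
# sibling transformer_encoder / transformer_decoder modules are importable and
# follow the call signatures used above. Shapes follow the comments in
# forward(): src is [B, C, H, W], src_mask is a boolean padding mask [B, H, W]
# (True = padded position), query_embed is [Nq, C] with C == hidden_dim.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bs, c, h, w, num_queries = 2, 512, 20, 25, 100

    model = DETRTransformer(hidden_dim=c, num_heads=8, return_intermediate_dec=True)

    src = torch.randn(bs, c, h, w)                      # backbone feature map
    src_mask = torch.zeros(bs, h, w, dtype=torch.bool)  # no padded positions
    query_embed = torch.randn(num_queries, c)           # learned object queries

    hs, memory = model(src, src_mask, query_embed)
    # Expected under these assumptions:
    #   hs:     [num_dec_layers, B, Nq, C] when return_intermediate_dec=True
    #   memory: [B, C, H, W]
    print(hs.shape, memory.shape)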