# rtdetr_decoder.py
import math

import torch
import torch.nn as nn

from .rtdetr_basic import get_clones, TRDecoderLayer, MLP
  4. # Transformer Decoder Module
  5. class TransformerDecoder(nn.Module):
  6. def __init__(self, cfg, in_dim, return_intermediate=False):
  7. super().__init__()
  8. # -------------------- Basic Parameters ---------------------
  9. self.d_model = in_dim
  10. self.query_dim = 4 # For RefPoint head
  11. self.scale = 2 * 3.141592653589793
  12. self.num_queries = cfg['num_queries']
  13. self.num_deocder_layers = cfg['num_decoder_layers']
  14. self.return_intermediate = return_intermediate
  15. self.ffn_dim = round(cfg['de_dim_feedforward']*cfg['width'])
  16. # -------------------- Network Parameters ---------------------
  17. ## Decoder
  18. decoder_layer = TRDecoderLayer(
  19. d_model=in_dim,
  20. dim_feedforward=self.ffn_dim,
  21. num_heads=cfg['de_num_heads'],
  22. dropout=cfg['de_dropout'],
  23. act_type=cfg['de_act']
  24. )
  25. self.decoder_layers = get_clones(decoder_layer, cfg['num_decoder_layers'])
  26. ## RefPoint Embed
  27. self.refpoint_embed = nn.Embedding(cfg['num_queries'], 4)
  28. self.ref_point_head = MLP(self.query_dim // 2 * in_dim, in_dim, in_dim, 2)
  29. ## Object Query Embed
  30. self.object_query = nn.Embedding(cfg['num_queries'], in_dim)
  31. nn.init.normal_(self.object_query.weight.data)
  32. ## TODO: Group queries
  33. self.bbox_embed = None
  34. self.class_embed = None
  35. def inverse_sigmoid(self, x):
  36. x = x.clamp(min=0, max=1)
  37. return torch.log(x.clamp(min=1e-5)/(1 - x).clamp(min=1e-5))
  38. def query_sine_embed(self, num_feats, reference_points):
  39. dim_t = torch.arange(num_feats, dtype=torch.float32, device=reference_points.device)
  40. dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_feats
  41. dim_t = 10000 ** (2 * dim_t_)
  42. x_embed = reference_points[:, :, 0] * self.scale
  43. y_embed = reference_points[:, :, 1] * self.scale
  44. pos_x = x_embed[:, :, None] / dim_t
  45. pos_y = y_embed[:, :, None] / dim_t
  46. pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
  47. pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
  48. w_embed = reference_points[:, :, 2] * self.scale
  49. pos_w = w_embed[:, :, None] / dim_t
  50. pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
  51. h_embed = reference_points[:, :, 3] * self.scale
  52. pos_h = h_embed[:, :, None] / dim_t
  53. pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
  54. query_sine_embed = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
  55. return query_sine_embed
  56. def forward(self, memory, memory_pos):
  57. bs, _, channels = memory.size()
  58. num_feats = channels // 2
  59. # prepare tgt & refpoint
  60. tgt = self.object_query.weight[None].repeat(bs, 1, 1)
  61. refpoint_embed = self.refpoint_embed.weight[None].repeat(bs, 1, 1)
  62. intermediate = []
  63. reference_points = refpoint_embed.sigmoid()
  64. ref_points = [reference_points]
  65. # main process
  66. output = tgt
  67. for layer_id, layer in enumerate(self.decoder_layers):
  68. # Conditional query
  69. query_sine_embed = self.query_sine_embed(num_feats, reference_points)
  70. query_pos = self.ref_point_head(query_sine_embed) # [B, N, C]
  71. # Decoder
  72. output = layer(
  73. # input for decoder
  74. tgt = output,
  75. tgt_query_pos = query_pos,
  76. # input from encoder
  77. memory = memory,
  78. memory_pos = memory_pos,
  79. )
  80. # Iter update
  81. if self.bbox_embed is not None:
  82. delta_unsig = self.bbox_embed[layer_id](output)
  83. outputs_unsig = delta_unsig + self.inverse_sigmoid(reference_points)
  84. new_reference_points = outputs_unsig.sigmoid()
  85. reference_points = new_reference_points.detach()
  86. ref_points.append(new_reference_points)
  87. intermediate.append(output)
  88. return torch.stack(intermediate), torch.stack(ref_points)
  89. # build detection head
  90. def build_decoder(cfg, in_dim, return_intermediate=False):
  91. decoder = TransformerDecoder(cfg, in_dim, return_intermediate=return_intermediate)
  92. return decoder