# rtdetr_decoder.py

import torch
import torch.nn as nn

from .rtdetr_basic import get_clones, TRDecoderLayer, MLP
# Transformer Decoder Module
class TransformerDecoder(nn.Module):
    def __init__(self, cfg, in_dim, return_intermediate=False):
        super().__init__()
        self.d_model = in_dim
        self.query_dim = 4
        self.scale = 2 * 3.141592653589793  # 2 * pi, used by the sinusoidal embedding
        self.num_queries = cfg['num_queries']
        self.num_decoder_layers = cfg['num_decoder_layers']
        self.return_intermediate = return_intermediate

        # -------------------- Network Parameters ---------------------
        ## Decoder
        decoder_layer = TRDecoderLayer(
            d_model=in_dim,
            num_heads=cfg['de_num_heads'],
            dim_feedforward=cfg['de_dim_feedforward'],
            dropout=cfg['de_dropout'],
            act_type=cfg['de_act']
        )
        self.decoder_layers = get_clones(decoder_layer, cfg['num_decoder_layers'])

        ## RefPoint Embed: one (cx, cy, w, h) anchor per query, learned in logit space
        self.refpoint_embed = nn.Embedding(cfg['num_queries'], 4)

        ## Object Query Embed
        self.object_query = nn.Embedding(cfg['num_queries'], in_dim)
        nn.init.normal_(self.object_query.weight.data)

        ## TODO: Group queries
        self.ref_point_head = MLP(self.query_dim // 2 * in_dim, in_dim, in_dim, 2)
        self.query_pos_sine_scale = MLP(in_dim, in_dim, in_dim, 2)  # not used in forward() below
        self.ref_anchor_head = MLP(in_dim, in_dim, 2, 2)            # not used in forward() below

        ## Assigned externally; when bbox_embed is set, forward() refines the
        ## reference points after every decoder layer.
        self.bbox_embed = None
        self.class_embed = None
    def query_sine_embed(self, num_feats, reference_points):
        # Sinusoidal embedding of the normalized (cx, cy, w, h) reference points.
        # Frequencies follow the usual 10000^(2 * floor(i / 2) / num_feats) schedule.
        dim_t = torch.arange(num_feats, dtype=torch.float32, device=reference_points.device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_feats
        dim_t = 10000 ** (2 * dim_t_)

        x_embed = reference_points[:, :, 0] * self.scale
        y_embed = reference_points[:, :, 1] * self.scale
        pos_x = x_embed[:, :, None] / dim_t
        pos_y = y_embed[:, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
        pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)

        w_embed = reference_points[:, :, 2] * self.scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

        h_embed = reference_points[:, :, 3] * self.scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

        # Each coordinate contributes num_feats features -> [B, N, 4 * num_feats]
        query_sine_embed = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)

        return query_sine_embed
    def forward(self, memory, memory_pos):
        # memory / memory_pos: [B, M, C] encoder tokens and their positional encodings,
        # with C expected to equal d_model.
        bs, _, channels = memory.size()
        num_feats = channels // 2

        # prepare tgt & refpoint
        tgt = self.object_query.weight[None].repeat(bs, 1, 1)
        refpoint_embed = self.refpoint_embed.weight[None].repeat(bs, 1, 1)

        intermediate = []
        reference_points = refpoint_embed.sigmoid()
        ref_points = [reference_points]

        # main process
        output = tgt
        for layer_id, layer in enumerate(self.decoder_layers):
            # query sine embed: [B, N, 2 * C]
            query_sine_embed = self.query_sine_embed(num_feats, reference_points)
            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # [B, N, C]

            # decoder
            output = layer(
                # input for decoder
                tgt=output,
                tgt_query_pos=query_pos,
                # input from encoder
                memory=memory,
                memory_pos=memory_pos,
            )

            # iterative update of the reference points
            if self.bbox_embed is not None:
                # --------------- Start inverse_sigmoid ---------------
                reference_points = reference_points.clamp(min=0, max=1)
                reference_points_1 = reference_points.clamp(min=1e-5)
                reference_points_2 = (1 - reference_points).clamp(min=1e-5)
                reference_before_sigmoid = torch.log(reference_points_1 / reference_points_2)
                # --------------- End inverse_sigmoid ---------------
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # detach so box gradients do not propagate across decoder layers
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(output)

        # intermediate: [num_layers, B, N, C]
        # ref_points:   [num_layers + 1, B, N, 4] if bbox_embed is set, else [1, B, N, 4]
        return torch.stack(intermediate), torch.stack(ref_points)
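
# For reference only: the inline inverse-sigmoid block in forward() above is equivalent
# to this standalone helper (an illustrative sketch; it is not used by TransformerDecoder).
def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable inverse of torch.sigmoid for x in [0, 1]."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)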
# Build the transformer decoder
def build_decoder(cfg, in_dim, return_intermediate=False):
    decoder = TransformerDecoder(cfg, in_dim, return_intermediate=return_intermediate)

    return decoder
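
# Minimal usage sketch (illustrative only): the cfg keys are the ones read above, but the
# values are placeholders, not the repository's actual configuration. Running this file
# directly would also require resolving the relative import of .rtdetr_basic.
if __name__ == '__main__':
    cfg = {
        'num_queries': 300,
        'num_decoder_layers': 6,
        'de_num_heads': 8,
        'de_dim_feedforward': 1024,
        'de_dropout': 0.0,
        'de_act': 'relu',
    }
    decoder = build_decoder(cfg, in_dim=256)

    # Dummy encoder output: 2 images, 400 tokens, 256 channels (channels must match in_dim).
    memory = torch.randn(2, 400, 256)
    memory_pos = torch.randn(2, 400, 256)

    hs, ref_points = decoder(memory, memory_pos)
    print(hs.shape)          # [num_decoder_layers, 2, num_queries, 256]
    print(ref_points.shape)  # [1, 2, num_queries, 4] until bbox_embed is attached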