# rtrdet_decoder.py

import torch
import torch.nn as nn
import math

from .rtrdet_basic import get_clones, TRDecoderLayer, MLP


# Transformer Decoder Module
class TransformerDecoder(nn.Module):
    def __init__(self, cfg, num_classes, return_intermediate=False):
        super().__init__()
        # -------------------- Basic Parameters ---------------------
        self.d_model = round(cfg['d_model'] * cfg['width'])
        self.num_queries = cfg['decoder_num_queries']
        self.num_pattern = cfg['decoder_num_pattern']
        self.num_decoder = cfg['num_decoder']
        self.num_classes = num_classes
        self.stop_layer_id = cfg['num_decoder'] if cfg['stop_layer_id'] == -1 else cfg['stop_layer_id']
        self.return_intermediate = return_intermediate
        self.scale = 2 * math.pi

        # -------------------- Network Parameters ---------------------
        ## Decoder
        decoder_layer = TRDecoderLayer(d_model=self.d_model,
                                       num_heads=cfg['decoder_num_head'],
                                       mlp_ratio=cfg['decoder_mlp_ratio'],
                                       dropout=cfg['decoder_dropout'],
                                       act_type=cfg['decoder_act'])
        self.decoder_layers = get_clones(decoder_layer, self.num_decoder)
        ## Pattern embed
        self.pattern = nn.Embedding(self.num_pattern, self.d_model)
        ## Spatial embed: one learnable (x, y) anchor per query
        self.position = nn.Embedding(self.num_queries, 2)
        ## Output head
        self.class_embed = nn.Linear(self.d_model, self.num_classes)
        self.bbox_embed = MLP(self.d_model, self.d_model, 4, 3)
        ## Adaptive pos_embed
        self.adapt_pos2d = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
        )

        self._reset_parameters()

    def _reset_parameters(self):
        # Bias the classification head so that the initial foreground
        # probability is prior_prob (focal-loss-style initialization).
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.class_embed.bias.data, bias_value)
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
        nn.init.uniform_(self.position.weight.data, 0, 1)
        # The same head instance is shared across all decoder layers.
        self.class_embed = nn.ModuleList([self.class_embed for _ in range(self.num_decoder)])
        self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(self.num_decoder)])

    def pos2posemb2d(self, pos, temperature=10000):
        """Map 2-D points in [0, 1] to sinusoidal position embeddings of size d_model."""
        pos = pos * self.scale
        num_pos_feats = self.d_model // 2
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        pos_x = pos[..., 0, None] / dim_t
        pos_y = pos[..., 1, None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
        posemb = torch.cat((pos_y, pos_x), dim=-1)

        return posemb

    def forward(self, memory, memory_pos):
        # reshape: [B, C, H, W] -> [B, N, C], N = HW
        memory = memory.flatten(2).permute(0, 2, 1).contiguous()
        memory_pos = memory_pos.flatten(2).permute(0, 2, 1).contiguous()
        memory_pos = self.adapt_pos2d(memory_pos)
        bs, _, channels = memory.size()

        # reshape: [Np, C] -> [1, Np, 1, C] -> [B, Np, Na, C] -> [B, Nq, C], Nq = Np*Na
        tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, channels).repeat(bs, 1, self.num_queries, 1)
        tgt = tgt.reshape(bs, self.num_pattern * self.num_queries, channels)
        # Reference points: [Na, 2] -> [B, Np*Na, 2]
        reference_points = self.position.weight.unsqueeze(0).repeat(bs, self.num_pattern, 1)

        # Decoder
        output_classes = []
        output_coords = []
        for layer_id, layer in enumerate(self.decoder_layers):
            # query embed from the current reference points
            query_pos = self.adapt_pos2d(self.pos2posemb2d(reference_points))
            tgt = layer(tgt, memory, query_pos, memory_pos)
            reference = self.inverse_sigmoid(reference_points)
            ## class
            outputs_class = self.class_embed[layer_id](tgt)
            ## bbox: predict offsets, add the un-sigmoided reference centers
            tmp = self.bbox_embed[layer_id](tgt)
            tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()

            output_classes.append(outputs_class)
            output_coords.append(outputs_coord)

            if layer_id == self.stop_layer_id:
                break

        return torch.stack(output_classes), torch.stack(output_coords)

    def inverse_sigmoid(self, x):
        x = x.clamp(min=0, max=1)
        return torch.log(x.clamp(min=1e-5) / (1 - x).clamp(min=1e-5))


# Build the Transformer decoder
def build_decoder(cfg, num_classes, return_intermediate=False):
    decoder = TransformerDecoder(cfg, num_classes, return_intermediate)

    return decoder
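

if __name__ == '__main__':
    # Minimal smoke test -- a sketch, not the project's official config.
    # All cfg values below are illustrative assumptions; valid choices for
    # keys such as 'decoder_act' depend on rtrdet_basic.TRDecoderLayer.
    # Note: the relative import at the top means this file must be run as a
    # module from the package root, e.g. `python -m <package>.rtrdet_decoder`.
    cfg = {
        'd_model': 256,
        'width': 1.0,
        'decoder_num_queries': 300,
        'decoder_num_pattern': 3,
        'num_decoder': 6,
        'stop_layer_id': -1,
        'decoder_num_head': 8,
        'decoder_mlp_ratio': 4.0,
        'decoder_dropout': 0.1,
        'decoder_act': 'relu',
    }
    decoder = build_decoder(cfg, num_classes=80)
    # memory: [B, C, H, W] encoder feature map; memory_pos: its position
    # encoding with the same shape (C must equal d_model here).
    memory = torch.randn(2, 256, 20, 20)
    memory_pos = torch.randn(2, 256, 20, 20)
    out_classes, out_coords = decoder(memory, memory_pos)
    # out_classes: [L, B, Np*Na, num_classes]; out_coords: [L, B, Np*Na, 4]
    print(out_classes.shape, out_coords.shape)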