# rtrdet_transformer.py
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. from .rtrdet_basic import get_clones, TREncoderLayer, TRDecoderLayer, MLP
  5. class RTRDetTransformer(nn.Module):
  6. def __init__(self, cfg, num_classes, return_intermediate):
  7. super().__init__()
  8. # -------------------- Basic Parameters ---------------------
  9. self.d_model = round(cfg['d_model']*cfg['width'])
  10. self.num_classes = num_classes
  11. self.num_encoder = cfg['num_encoder']
  12. self.num_deocder = cfg['num_decoder']
  13. self.num_queries = cfg['decoder_num_queries']
  14. self.num_pattern = cfg['decoder_num_pattern']
  15. self.stop_layer_id = cfg['num_decoder'] if cfg['stop_layer_id'] == -1 else cfg['stop_layer_id']
  16. self.return_intermediate = return_intermediate
  17. self.scale = 2 * 3.141592653589793
  18. # -------------------- Network Parameters ---------------------
  19. ## Transformer Encoder
  20. encoder_layer = TREncoderLayer(
  21. self.d_model, cfg['encoder_num_head'], cfg['encoder_mlp_ratio'], cfg['encoder_dropout'], cfg['encoder_act'])
  22. self.encoder_layers = get_clones(encoder_layer, cfg['num_encoder'])
  23. ## Transformer Decoder
  24. decoder_layer = TRDecoderLayer(
  25. self.d_model, cfg['decoder_num_head'], cfg['decoder_mlp_ratio'], cfg['decoder_dropout'], cfg['decoder_act'])
  26. self.decoder_layers = get_clones(decoder_layer, cfg['num_decoder'])
  27. ## Pattern embed
  28. self.pattern = nn.Embedding(cfg['decoder_num_pattern'], self.d_model)
  29. ## Position embed
  30. self.position = nn.Embedding(cfg['decoder_num_queries'], 2)
  31. ## Adaptive PosEmbed
  32. self.adapt_pos2d = nn.Sequential(
  33. nn.Linear(self.d_model, self.d_model),
  34. nn.ReLU(),
  35. nn.Linear(self.d_model, self.d_model),
  36. )
  37. ## Output head
  38. self.class_embed = nn.Linear(self.d_model, self.num_classes)
  39. self.bbox_embed = MLP(self.d_model, self.d_model, 4, 3)
  40. self._reset_parameters()
  41. def _reset_parameters(self):
  42. prior_prob = 0.01
  43. bias_value = -math.log((1 - prior_prob) / prior_prob)
  44. self.class_embed.bias.data = torch.ones(self.num_classes) * bias_value
  45. nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
  46. nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
  47. nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
  48. nn.init.uniform_(self.position.weight.data, 0, 1)
  49. self.class_embed = nn.ModuleList([self.class_embed for _ in range(self.num_deocder)])
  50. self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(self.num_deocder)])
  51. def generate_posembed(self, x, temperature=10000):
  52. hs, ws, num_pos_feats = x.shape[2], x.shape[3], x.shape[1]//2
  53. # generate xy coord mat
  54. y_embed, x_embed = torch.meshgrid(
  55. [torch.arange(1, hs+1, dtype=torch.float32),
  56. torch.arange(1, ws+1, dtype=torch.float32)])
  57. y_embed = y_embed / (hs + 1e-6) * self.scale
  58. x_embed = x_embed / (ws + 1e-6) * self.scale
  59. # [H, W] -> [1, H, W]
  60. y_embed = y_embed[None, :, :].to(x.device)
  61. x_embed = x_embed[None, :, :].to(x.device)
  62. dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device)
  63. dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
  64. dim_t = temperature ** (2 * dim_t_)
  65. pos_x = torch.div(x_embed[..., None], dim_t)
  66. pos_y = torch.div(y_embed[..., None], dim_t)
  67. pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
  68. pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
  69. # [B, H, W, C] -> [B, C, H, W]
  70. pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  71. return pos_embed
  72. def pos2posemb2d(self, pos, temperature=10000):
  73. scale = 2 * math.pi
  74. num_pos_feats = self.d_model // 2
  75. pos = pos * scale
  76. dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
  77. dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
  78. dim_t = temperature ** (2 * dim_t_)
  79. pos_x = pos[..., 0, None] / dim_t
  80. pos_y = pos[..., 1, None] / dim_t
  81. pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
  82. pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
  83. posemb = torch.cat((pos_y, pos_x), dim=-1)
  84. return posemb
  85. def inverse_sigmoid(self, x):
  86. x = x.clamp(min=0, max=1)
  87. return torch.log(x.clamp(min=1e-5)/(1 - x).clamp(min=1e-5))
  88. def forward(self, src1=None, src2=None):
  89. """
  90. Input:
  91. src1: C4-level feature -> [B, C4, H4, W4]
  92. sec2: C5-level feature -> [B, C5, H5, W5]
  93. Output:
  94. """
  95. bs, c, h, w = src2.size()
  96. # ------------------------ Transformer Encoder ------------------------
  97. ## Generate pos_embed for src2
  98. pos2d_embed_2 = self.generate_posembed(src2)
  99. ## Reshape: [B, C, H, W] -> [B, N, C], N = HW
  100. src2 = src2.flatten(2).permute(0, 2, 1).contiguous()
  101. pos2d_embed_2 = self.adapt_pos2d(pos2d_embed_2.flatten(2).permute(0, 2, 1).contiguous())
  102. ## Encoder layer
  103. for layer_id, encoder_layer in enumerate(self.encoder_layers):
  104. src2 = encoder_layer(src2, pos2d_embed_2)
  105. ## Feature fusion
  106. src2 = src2.permute(0, 2, 1).reshape(bs, c, h, w)
  107. if src1 is not None:
  108. src1 = src1 + nn.functional.interpolate(src2, scale_factor=2.0)
  109. else:
  110. src1 = src2
  111. # ------------------------ Transformer Decoder ------------------------
  112. ## Generate pos_embed for src1
  113. pos2d_embed_1 = self.generate_posembed(src1)
  114. ## Reshape memory: [B, C, H, W] -> [B, N, C], N = HW
  115. src1 = src1.flatten(2).permute(0, 2, 1).contiguous()
  116. pos2d_embed_1 = self.adapt_pos2d(pos2d_embed_1.flatten(2).permute(0, 2, 1).contiguous())
  117. ## Reshape tgt: [Na, C] -> [1, Na, 1, C] -> [1, Na, Np, C] -> [1, Nq, C], Nq = Na*Np
  118. tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, c).repeat(bs, 1, self.num_queries, 1)
  119. tgt = tgt.reshape(bs, self.num_pattern * self.num_queries, c)
  120. ## Prepare reference points
  121. reference_points = self.position.weight.unsqueeze(0).repeat(bs, self.num_pattern, 1)
  122. ## Decoder layer
  123. output_classes = []
  124. output_coords = []
  125. for layer_id, decoder_layer in enumerate(self.decoder_layers):
  126. ## query embed
  127. query_pos = self.adapt_pos2d(self.pos2posemb2d(reference_points))
  128. tgt = decoder_layer(tgt, query_pos, src1, pos2d_embed_1)
  129. reference = self.inverse_sigmoid(reference_points)
  130. ## class
  131. outputs_class = self.class_embed[layer_id](tgt)
  132. ## bbox
  133. tmp = self.bbox_embed[layer_id](tgt)
  134. tmp[..., :2] += reference
  135. outputs_coord = tmp.sigmoid()
  136. output_classes.append(outputs_class)
  137. output_coords.append(outputs_coord)
  138. if layer_id == self.stop_layer_id:
  139. break
  140. return torch.stack(output_classes), torch.stack(output_coords)
  141. # build detection head
  142. def build_transformer(cfg, num_classes, return_intermediate=False):
  143. if cfg['transformer'] == "RTRDetTransformer":
  144. transoformer = RTRDetTransformer(cfg, num_classes, return_intermediate)
  145. return transoformer