| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- import torch
- import torch.nn as nn
- from .rtdetr_basic import get_clones, TRDecoderLayer, MLP
- # Transformer Decoder Module
class TransformerDecoder(nn.Module):
    """Transformer decoder driven by learnable queries and reference boxes.

    Each decoder layer receives a query positional embedding computed from the
    current 4-value reference points (sinusoidal encoding followed by an MLP).
    When ``self.bbox_embed`` is set, the reference points are refined after
    every layer.

    Args:
        cfg: Mapping that provides ``'num_queries'``, ``'num_decoder_layers'``,
            ``'de_dim_feedforward'``, ``'width'``, ``'de_num_heads'``,
            ``'de_dropout'`` and ``'de_act'``.
        in_dim: Channel dimension of the decoder (``d_model``).
        return_intermediate: Stored on the instance. NOTE(review): ``forward``
            does not read this flag and always returns every layer's output.
    """

    def __init__(self, cfg, in_dim, return_intermediate=False):
        super().__init__()
        # -------------------- Basic Parameters ---------------------
        self.d_model = in_dim
        self.query_dim = 4  # For RefPoint head
        self.scale = 2 * 3.141592653589793
        self.num_queries = cfg['num_queries']
        self.num_decoder_layers = cfg['num_decoder_layers']
        # Misspelled legacy name, kept so existing readers of it keep working.
        self.num_deocder_layers = self.num_decoder_layers
        self.return_intermediate = return_intermediate
        self.ffn_dim = round(cfg['de_dim_feedforward'] * cfg['width'])

        # -------------------- Network Parameters ---------------------
        ## Decoder
        decoder_layer = TRDecoderLayer(
            d_model=in_dim,
            dim_feedforward=self.ffn_dim,
            num_heads=cfg['de_num_heads'],
            dropout=cfg['de_dropout'],
            act_type=cfg['de_act'],
        )
        self.decoder_layers = get_clones(decoder_layer, self.num_decoder_layers)

        ## RefPoint Embed
        self.refpoint_embed = nn.Embedding(self.num_queries, self.query_dim)
        # The sine embedding concatenates four encodings of ``channels // 2``
        # features each, hence the ``query_dim // 2 * in_dim`` input width.
        self.ref_point_head = MLP(self.query_dim // 2 * in_dim, in_dim, in_dim, 2)

        ## Object Query Embed
        self.object_query = nn.Embedding(self.num_queries, in_dim)
        nn.init.normal_(self.object_query.weight.data)

        ## TODO: Group queries

        # Left unset here. ``forward`` indexes ``bbox_embed`` by layer id when
        # it is not None; presumably both heads are attached by the owning
        # model - verify against the caller.
        self.bbox_embed = None
        self.class_embed = None

    def inverse_sigmoid(self, x):
        """Return the logit of ``x`` with clamping for numerical safety.

        ``x`` is first clamped to ``[0, 1]``; numerator and denominator are
        then clamped to at least ``1e-5`` so the logarithm stays finite.
        """
        x = x.clamp(min=0, max=1)
        return torch.log(x.clamp(min=1e-5) / (1 - x).clamp(min=1e-5))

    @staticmethod
    def _sine_encode(values, dim_t):
        """Encode ``values`` ([B, N]) as interleaved sin/cos features ([B, N, F])."""
        pos = values[:, :, None] / dim_t
        # Even frequency slots take sin, odd slots take cos; stacking on a new
        # last axis and flattening interleaves them as (sin, cos, sin, cos, ...).
        return torch.stack(
            (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
        ).flatten(2)

    def query_sine_embed(self, num_feats, reference_points):
        """Build the sinusoidal embedding of 4-value reference points.

        Args:
            num_feats: Number of features produced per coordinate. It must be
                even, otherwise the sin and cos halves cannot be stacked.
            reference_points: Tensor indexed as ``[:, :, 0..3]``; the indices
                are treated as x, y, w and h respectively.

        Returns:
            Tensor of shape ``[B, N, 4 * num_feats]`` holding the encodings
            concatenated in (y, x, w, h) order.
        """
        dim_t = torch.arange(num_feats, dtype=torch.float32, device=reference_points.device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_feats
        dim_t = 10000 ** (2 * dim_t_)

        # Coordinate indices in output order: y comes before x, then w and h.
        encodings = [
            self._sine_encode(reference_points[:, :, index] * self.scale, dim_t)
            for index in (1, 0, 2, 3)
        ]
        return torch.cat(encodings, dim=2)

    def forward(self, memory, memory_pos):
        """Run every decoder layer over the encoder memory.

        Args:
            memory: 3-D encoder output; the first dimension is the batch size
                and the last is the channel count.
            memory_pos: Positional embedding for ``memory``, passed through to
                each decoder layer unchanged.

        Returns:
            A tuple ``(outputs, ref_points)``. ``outputs`` stacks the output of
            every decoder layer. ``ref_points`` stacks the initial reference
            points followed by one refined set per layer when ``bbox_embed`` is
            set; otherwise it contains only the initial reference points.
        """
        bs, _, channels = memory.size()
        num_feats = channels // 2

        # prepare tgt & refpoint
        tgt = self.object_query.weight[None].repeat(bs, 1, 1)
        refpoint_embed = self.refpoint_embed.weight[None].repeat(bs, 1, 1)
        reference_points = refpoint_embed.sigmoid()

        intermediate = []
        ref_points = [reference_points]

        # main process
        output = tgt
        for layer_id, layer in enumerate(self.decoder_layers):
            # Conditional query
            query_sine_embed = self.query_sine_embed(num_feats, reference_points)
            query_pos = self.ref_point_head(query_sine_embed)  # [B, N, C]

            # Decoder
            output = layer(
                # input for decoder
                tgt=output,
                tgt_query_pos=query_pos,
                # input from encoder
                memory=memory,
                memory_pos=memory_pos,
            )

            # Iter update
            if self.bbox_embed is not None:
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + self.inverse_sigmoid(reference_points)
                new_reference_points = outputs_unsig.sigmoid()
                # The next layer sees a detached copy, so its gradients do not
                # flow back through earlier refinements; the non-detached
                # tensor is what gets returned.
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(output)

        return torch.stack(intermediate), torch.stack(ref_points)
- # build transformer decoder
def build_decoder(cfg, in_dim, return_intermediate=False):
    """Create a ``TransformerDecoder`` from ``cfg`` and ``in_dim``.

    ``return_intermediate`` is forwarded to the constructor unchanged.
    """
    return TransformerDecoder(
        cfg,
        in_dim,
        return_intermediate=return_intermediate,
    )
|