@@ -8,8 +8,9 @@ from .rtdetr_basic import get_clones, TRDecoderLayer, MLP
 class TransformerDecoder(nn.Module):
     def __init__(self, cfg, in_dim, return_intermediate=False):
         super().__init__()
+        # -------------------- Basic Parameters ---------------------
         self.d_model = in_dim
-        self.query_dim = 4
+        self.query_dim = 4  # For RefPoint head
         self.scale = 2 * 3.141592653589793
         self.num_queries = cfg['num_queries']
         self.num_decoder_layers = cfg['num_decoder_layers']
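To make the constructor's expectations concrete, here is a hypothetical `cfg` covering the two keys this hunk reads (the values shown are typical defaults, not taken from the diff):

```python
# Hypothetical config; only 'num_queries' and 'num_decoder_layers' appear in the diff.
cfg = {
    'num_queries': 300,        # assumed typical value
    'num_decoder_layers': 6,   # assumed typical value
}
decoder = TransformerDecoder(cfg, in_dim=256)  # in_dim sets self.d_model
```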
@@ -82,13 +83,11 @@ class TransformerDecoder(nn.Module):
         # main process
         output = tgt
         for layer_id, layer in enumerate(self.decoder_layers):
-            # query sine embed
+            # Conditional query
             query_sine_embed = self.query_sine_embed(num_feats, reference_points)
-
-            # conditional query
             query_pos = self.ref_point_head(query_sine_embed)  # [B, N, C]

-            # decoder
+            # Decoder
             output = layer(
                 # input for decoder
                 tgt = output,
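The conditional-query step above calls `self.query_sine_embed(num_feats, reference_points)` without showing its body. Below is a minimal sketch of the DAB-DETR-style sinusoidal embedding such a method typically computes, consistent with `self.scale = 2 * pi` and `self.query_dim = 4` from the first hunk; the `temperature` value and the exact sin/cos interleaving are assumptions, not taken from this diff.

```python
import torch

def query_sine_embed(num_feats, reference_points, temperature=10000,
                     scale=2 * 3.141592653589793):
    # reference_points: [B, N, 4] normalized (cx, cy, w, h) boxes.
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=reference_points.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_feats)

    embeds = []
    for i in range(reference_points.shape[-1]):    # one sine embedding per box coordinate
        pos = reference_points[..., i] * scale     # map [0, 1] coordinates onto [0, 2*pi]
        pos = pos[..., None] / dim_t               # [B, N, num_feats]
        pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)
        embeds.append(pos)
    return torch.cat(embeds, dim=-1)               # [B, N, 4 * num_feats]
```

`self.ref_point_head` (presumably the imported `MLP`) then projects this `4 * num_feats` vector down to `d_model`, producing the per-query positional embedding `query_pos`.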
@@ -98,7 +97,7 @@ class TransformerDecoder(nn.Module):
                 memory_pos = memory_pos,
             )

-            # iter update
+            # Iter update
             if self.bbox_embed is not None:
                 delta_unsig = self.bbox_embed[layer_id](output)
                 outputs_unsig = delta_unsig + self.inverse_sigmoid(reference_points)
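The refinement in the last hunk adds each layer's predicted correction to the reference points in logit space; the excerpt ends before the result is squashed back to [0, 1]. The sketch below shows the standard `inverse_sigmoid` used across DETR variants, plus the closing `.sigmoid()` step, which is assumed from the usual iterative-refinement pattern rather than taken from this diff:

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    # Logit of x; eps clamping keeps log() finite when a coordinate hits 0 or 1.
    x = x.clamp(min=0.0, max=1.0)
    return torch.log(x.clamp(min=eps) / (1.0 - x).clamp(min=eps))

# Hypothetical shapes: batch of 2, 300 queries, (cx, cy, w, h) boxes.
reference_points = torch.rand(2, 300, 4)
delta_unsig = torch.randn(2, 300, 4)   # stands in for self.bbox_embed[layer_id](output)

outputs_unsig = delta_unsig + inverse_sigmoid(reference_points)
new_reference_points = outputs_unsig.sigmoid()   # assumed follow-up for the next layer
```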