|
|
@@ -30,7 +30,7 @@ class RTDETR(nn.Module):
|
|
|
|
|
|
# --------- Network Parameters ----------
|
|
|
## Encoder
|
|
|
- self.img_encoder = build_encoder(cfg, trainable, 'img_encoder')
|
|
|
+ self.encoder = build_encoder(cfg, trainable, 'img_encoder')
|
|
|
|
|
|
## Decoder
|
|
|
self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
|
|
|
@@ -57,7 +57,7 @@ class RTDETR(nn.Module):
|
|
|
@torch.no_grad()
|
|
|
def inference_single_image(self, x):
|
|
|
# -------------------- Encoder --------------------
|
|
|
- memory, memory_pos = self.img_encoder(x)
|
|
|
+ memory, memory_pos = self.encoder(x)
|
|
|
|
|
|
# -------------------- Decoder --------------------
|
|
|
hs, reference = self.decoder(memory, memory_pos)
|
|
|
@@ -93,7 +93,7 @@ class RTDETR(nn.Module):
|
|
|
return self.inference_single_image(x)
|
|
|
else:
|
|
|
# -------------------- Encoder --------------------
|
|
|
- memory, memory_pos = self.img_encoder(x)
|
|
|
+ memory, memory_pos = self.encoder(x)
|
|
|
|
|
|
# -------------------- Decoder --------------------
|
|
|
hs, reference = self.decoder(memory, memory_pos)
|