|
@@ -2,6 +2,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from .rtdetr_encoder import build_encoder
|
|
from .rtdetr_encoder import build_encoder
|
|
|
|
|
+from .rtdetr_compressor import build_compressor
|
|
|
from .rtdetr_decoder import build_decoder
|
|
from .rtdetr_decoder import build_decoder
|
|
|
from .rtdetr_dethead import build_dethead
|
|
from .rtdetr_dethead import build_dethead
|
|
|
|
|
|
|
@@ -32,6 +33,9 @@ class RTDETR(nn.Module):
|
|
|
## Encoder
|
|
## Encoder
|
|
|
self.encoder = build_encoder(cfg, trainable, 'img_encoder')
|
|
self.encoder = build_encoder(cfg, trainable, 'img_encoder')
|
|
|
|
|
|
|
|
|
|
+ ## Compressor
|
|
|
|
|
+ self.compressor = build_compressor(cfg, self.d_model)
|
|
|
|
|
+
|
|
|
## Decoder
|
|
## Decoder
|
|
|
self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
|
|
self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
|
|
|
|
|
|
|
@@ -97,8 +101,11 @@ class RTDETR(nn.Module):
|
|
|
memory = memory.permute(0, 2, 1).contiguous()
|
|
memory = memory.permute(0, 2, 1).contiguous()
|
|
|
memory_pos = memory_pos.permute(0, 2, 1).contiguous()
|
|
memory_pos = memory_pos.permute(0, 2, 1).contiguous()
|
|
|
|
|
|
|
|
|
|
+ # -------------------- Compressor --------------------
|
|
|
|
|
+ compressed_memory = self.compressor(memory, memory_pos)
|
|
|
|
|
+
|
|
|
# -------------------- Decoder --------------------
|
|
# -------------------- Decoder --------------------
|
|
|
- hs, reference = self.decoder(memory, memory_pos)
|
|
|
|
|
|
|
+ hs, reference = self.decoder(compressed_memory, None)
|
|
|
|
|
|
|
|
# -------------------- DetHead --------------------
|
|
# -------------------- DetHead --------------------
|
|
|
out_logits, out_bbox = self.dethead(hs, reference, False)
|
|
out_logits, out_bbox = self.dethead(hs, reference, False)
|
|
@@ -139,8 +146,11 @@ class RTDETR(nn.Module):
|
|
|
memory = memory.permute(0, 2, 1).contiguous()
|
|
memory = memory.permute(0, 2, 1).contiguous()
|
|
|
memory_pos = memory_pos.permute(0, 2, 1).contiguous()
|
|
memory_pos = memory_pos.permute(0, 2, 1).contiguous()
|
|
|
|
|
|
|
|
|
|
+ # -------------------- Compressor --------------------
|
|
|
|
|
+ compressed_memory = self.compressor(memory, memory_pos)
|
|
|
|
|
+
|
|
|
# -------------------- Decoder --------------------
|
|
# -------------------- Decoder --------------------
|
|
|
- hs, reference = self.decoder(memory, memory_pos)
|
|
|
|
|
|
|
+ hs, reference = self.decoder(compressed_memory, None)
|
|
|
|
|
|
|
|
# -------------------- DetHead --------------------
|
|
# -------------------- DetHead --------------------
|
|
|
outputs_class, outputs_coords = self.dethead(hs, reference, True)
|
|
outputs_class, outputs_coords = self.dethead(hs, reference, True)
|