rtdetr_compressor.py

import torch
import torch.nn as nn

from .rtdetr_basic import TRDecoderLayer


# ---------------------- Memory Compressor ----------------------
class MemoryCompressor(nn.Module):
    """Compresses the encoder memory into a fixed set of learned query vectors
    via cross-attention in a single Transformer decoder layer."""

    def __init__(self, cfg, in_dim):
        super().__init__()
        # -------------------- Basic Parameters ---------------------
        self.d_model = in_dim
        self.ffn_dim = round(cfg['com_dim_feedforward'] * cfg['width'])
        # Learnable compressed queries: [num_compressed, in_dim]
        self.compressed_vector = nn.Embedding(cfg['num_compressed'], in_dim)

        # -------------------- Network Parameters ---------------------
        self.compress_layer = TRDecoderLayer(
            d_model=in_dim,
            dim_feedforward=self.ffn_dim,
            num_heads=cfg['com_num_heads'],
            dropout=cfg['com_dropout'],
            act_type=cfg['com_act'],
        )

    def forward(self, memory, memory_pos):
        # memory:     [bs, num_tokens, in_dim] encoder output
        # memory_pos: positional embeddings for the memory tokens
        bs = memory.size(0)
        # Broadcast the learned queries across the batch: [bs, num_compressed, in_dim]
        output = self.compressed_vector.weight[None].repeat(bs, 1, 1)
        # Cross-attend the compressed queries to the memory
        output = self.compress_layer(output, None, memory, memory_pos)

        return output


def build_compressor(cfg, in_dim):
    return MemoryCompressor(cfg, in_dim)
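
A minimal usage sketch (not part of the file above): it only shows the cfg keys that MemoryCompressor actually reads; the concrete values, tensor sizes, and import path are illustrative assumptions.

# usage_sketch.py -- illustrative only; adjust the import path to your project layout
import torch
from rtdetr_compressor import build_compressor  # hypothetical flat import

cfg = {
    'com_dim_feedforward': 1024,  # assumed value
    'width': 1.0,                 # assumed value
    'num_compressed': 100,        # assumed value
    'com_num_heads': 8,           # assumed value
    'com_dropout': 0.0,           # assumed value
    'com_act': 'relu',            # assumed value
}
compressor = build_compressor(cfg, in_dim=256)

memory = torch.randn(2, 600, 256)      # [bs, num_tokens, d_model], illustrative sizes
memory_pos = torch.randn(2, 600, 256)  # positional embeddings for the memory tokens
out = compressor(memory, memory_pos)   # expected shape: [2, num_compressed, 256]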