# rtrdet_encoder.py
import torch
import torch.nn as nn

from .rtrdet_basic import get_clones, TREncoderLayer
# Transformer Encoder Module
  5. class TransformerEncoder(nn.Module):
  6. def __init__(self, cfg):
  7. super().__init__()
  8. # -------------------- Basic Parameters ---------------------
  9. self.d_model = round(cfg['d_model']*cfg['width'])
  10. self.num_encoder = cfg['num_encoder']
  11. # -------------------- Network Parameters ---------------------
  12. encoder_layer = TREncoderLayer(d_model = self.d_model,
  13. num_heads = cfg['encoder_num_head'],
  14. mlp_ratio = cfg['encoder_mlp_ratio'],
  15. dropout = cfg['encoder_dropout'],
  16. act_type = cfg['encoder_act']
  17. )
  18. self.encoder_layers = get_clones(encoder_layer, self.num_encoder)
  19. def forward(self, feat, pos_embed, adapt_pos2d):
  20. # reshape: [B, C, H, W] -> [B, N, C], N = HW
  21. feat = feat.flatten(2).permute(0, 2, 1).contiguous()
  22. pos_embed = adapt_pos2d(pos_embed.flatten(2).permute(0, 2, 1).contiguous())
  23. # Transformer encoder
  24. for encoder in self.encoder_layers:
  25. feat = encoder(feat, pos_embed)
  26. return feat
  27. # build detection head
  28. def build_encoder(cfg):
  29. transformer_encoder = TransformerEncoder(cfg)
  30. return transformer_encoder