# img_encoder.py (1016 B)
  1. import torch
  2. import torch.nn as nn
  3. from .cnn_backbone import build_backbone
  4. from .cnn_neck import build_neck
  5. from .cnn_pafpn import build_fpn
  6. # ------------------------ Image Encoder ------------------------
  class ImageEncoder(nn.Module):
      """Image encoder chaining a backbone, a neck ("encoder") and a CSFM stage.

      The three sub-modules are created by the project factories
      ``build_backbone``, ``build_neck`` and ``build_fpn``.  ``forward`` runs
      the backbone, passes only the last backbone feature through the neck,
      and then hands all features to the CSFM stage.
      """

      def __init__(self, cfg, trainable=False) -> None:
          """Build the sub-modules from ``cfg``.

          Args:
              cfg: Mapping read here for the keys ``'pretrained'``,
                  ``'d_model'`` and ``'width'``; it is also forwarded
                  unchanged to each ``build_*`` factory.
              trainable: Multiplied with ``cfg['pretrained']`` to form the
                  second argument of ``build_backbone``.
          """
          super().__init__()
          ## Backbone
          # The product is falsy whenever ``trainable`` is False.
          # NOTE(review): presumably this disables pretrained-weight loading
          # outside of training -- confirm against ``build_backbone``.
          self.backbone, feats_dim = build_backbone(cfg, cfg['pretrained'] * trainable)

          ## Encoder
          # The neck is built with the last backbone dimension passed as both
          # of its dimension arguments.
          last_dim = feats_dim[-1]
          self.encoder = build_neck(cfg, last_dim, last_dim)

          ## CSFM
          self.csfm = build_fpn(
              cfg=cfg,
              in_dims=feats_dim,
              out_dim=round(cfg['d_model'] * cfg['width']),
              input_proj=True,
          )

      def forward(self, x):
          """Run backbone -> neck (last feature only) -> CSFM.

          Args:
              x: Input forwarded as-is to the backbone.
                  NOTE(review): presumably an image batch tensor -- confirm.

          Returns:
              Whatever ``self.csfm`` returns for the list of features.
          """
          # Backbone
          # Copy into a fresh list: item assignment below then also works when
          # the backbone returns a tuple, and the backbone's own container is
          # left unmodified.
          pyramid_feats = list(self.backbone(x))

          # Encoder
          pyramid_feats[-1] = self.encoder(pyramid_feats[-1])

          # CSFM
          return self.csfm(pyramid_feats)
  24. # build img-encoder
  def build_img_encoder(cfg, trainable):
      """Create an ``ImageEncoder`` from ``cfg``.

      Args:
          cfg: Configuration object passed straight to ``ImageEncoder``.
          trainable: Flag passed straight to ``ImageEncoder``.

      Returns:
          The newly constructed ``ImageEncoder`` instance.
      """
      return ImageEncoder(cfg, trainable)