rtdetr_encoder.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .basic_modules.backbone import build_backbone
  5. from .basic_modules.fpn import build_fpn
  6. # ----------------- Image Encoder -----------------
  7. class ImageEncoder(nn.Module):
  8. def __init__(self, cfg):
  9. super().__init__()
  10. # ---------------- Basic settings ----------------
  11. ## Basic parameters
  12. self.cfg = cfg
  13. ## Network parameters
  14. self.strides = cfg.out_stride
  15. self.hidden_dim = cfg.hidden_dim
  16. self.num_levels = len(self.strides)
  17. # ---------------- Network settings ----------------
  18. ## Backbone Network
  19. self.backbone = build_backbone(cfg, pretrained=cfg.pretrained)
  20. self.fpn_feat_dims = self.backbone.feat_dims[-3:]
  21. ## Feature Pyramid Network
  22. self.fpn = build_fpn(cfg, self.fpn_feat_dims)
  23. self.fpn_dims = self.fpn.out_dims
  24. def forward(self, x):
  25. pyramid_feats = self.backbone(x)
  26. pyramid_feats = self.fpn(pyramid_feats)
  27. return pyramid_feats