rtdetr_encoder.py 429 B

12345678910111213141516171819
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .basic_modules.backbone import build_backbone
  5. from .basic_modules.pafpn import build_pafpn
  6. # ----------------- Image Encoder -----------------
  7. class ImageEncoder(nn.Module):
  8. def __init__(self, ):
  9. super().__init__()
  10. self.backbone = None
  11. self.neck = None
  12. self.fpn = None
  13. def forward(self, x):
  14. return