# rtdetr_decoder.py

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  # ----------------- Decoder for Detection task -----------------
  class DetDecoder(nn.Module):
      """Placeholder decoder for the detection task.

      Holds three attribute slots -- ``backbone``, ``neck`` and ``fpn`` -- that
      are all ``None``, so the module owns no parameters or child modules yet.
      ``forward`` performs no computation.
      """

      def __init__(self):
          super(DetDecoder, self).__init__()
          # Empty slots (assigned in the order backbone, neck, fpn).
          # NOTE(review): these names mirror a full detector pipeline rather than
          # a decoder head -- confirm they are intended to live here.
          self.backbone = self.neck = self.fpn = None

      def forward(self, x):
          """Stub forward pass: ignores ``x`` and returns ``None``."""
          return None
  # ----------------- Decoder for Segmentation task -----------------
  class SegDecoder(nn.Module):
      """Placeholder decoder for the segmentation task.

      Defines no sub-modules or parameters yet; ``forward`` returns ``None``.
      """

      def __init__(self):
          super(SegDecoder, self).__init__()
          # TODO: design the segmentation decoder.

      def forward(self, x):
          """Stub forward pass: ignores ``x`` and returns ``None``."""
          return None
  # ----------------- Decoder for Pose Estimation task -----------------
  class PosDecoder(nn.Module):
      """Placeholder decoder for the pose-estimation task.

      Defines no sub-modules or parameters yet; ``forward`` returns ``None``.
      """

      def __init__(self):
          super(PosDecoder, self).__init__()
          # TODO: design the pose-estimation decoder.

      def forward(self, x):
          """Stub forward pass: ignores ``x`` and returns ``None``."""
          return None