vit_cls.py 822 B

12345678910111213141516171819202122232425262728
import torch
import torch.nn as nn

from .modules import AttentionPoolingClassifier
from .vit import ImageEncoderViT
  4. class ViTForImageClassification(nn.Module):
  5. def __init__(self,
  6. image_encoder :ImageEncoderViT,
  7. num_classes :int = 1000,
  8. qkv_bias :bool = True,
  9. ):
  10. super().__init__()
  11. # -------- Model parameters --------
  12. self.encoder = image_encoder
  13. self.classifier = AttentionPoolingClassifier(
  14. image_encoder.patch_embed_dim, num_classes, image_encoder.num_heads, qkv_bias, num_queries=1)
  15. def forward(self, x):
  16. """
  17. Inputs:
  18. x: (torch.Tensor) -> [B, C, H, W]. Input image.
  19. """
  20. x = self.encoder(x)
  21. x, x_cls = self.classifier(x)
  22. return x