| 12345678910111213141516171819202122232425262728 |
- import torch.nn as nn
- from .modules import AttentionPoolingClassifier
- from .vit import ImageEncoderViT
class ViTForImageClassification(nn.Module):
    """Image classifier: an image encoder followed by an attention-pooling head.

    The supplied encoder is registered as ``self.encoder`` and a newly built
    ``AttentionPoolingClassifier`` is registered as ``self.classifier``.
    """

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        num_classes: int = 1000,
        qkv_bias: bool = True,
    ):
        """
        Inputs:
            image_encoder: (ImageEncoderViT) -> Encoder applied to the input first. Its
                           `patch_embed_dim` and `num_heads` attributes are passed to the head.
            num_classes: (int) -> Passed to the classifier head (presumably the number
                         of output classes; confirm against AttentionPoolingClassifier).
            qkv_bias: (bool) -> Passed unchanged to the classifier head.
        """
        super().__init__()

        # -------- Model parameters --------
        # Build the head first, sized from the encoder's own attributes.
        head = AttentionPoolingClassifier(
            image_encoder.patch_embed_dim,
            num_classes,
            image_encoder.num_heads,
            qkv_bias,
            num_queries=1,
        )
        # Register the submodules in the same order as before: encoder, then classifier.
        self.encoder = image_encoder
        self.classifier = head

    def forward(self, x):
        """
        Inputs:
            x: (torch.Tensor) -> [B, C, H, W]. Input image.
        Outputs:
            The first of the two values returned by the classifier head.
        """
        features = self.encoder(x)
        # The head returns a pair; only its first element is exposed to callers.
        output, _ = self.classifier(features)
        return output
|