# build.py

import torch.nn as nn

from .vit import ImageEncoderViT, ViTForImageClassification
  3. def build_vit(args):
  4. if args.model == "vit_t":
  5. img_encoder = ImageEncoderViT(img_size=args.img_size,
  6. patch_size=args.patch_size,
  7. in_chans=args.img_dim,
  8. patch_embed_dim=192,
  9. depth=12,
  10. num_heads=3,
  11. mlp_ratio=4.0,
  12. act_layer=nn.GELU,
  13. dropout = 0.1)
  14. elif args.model == "vit_s":
  15. img_encoder = ImageEncoderViT(img_size=args.img_size,
  16. patch_size=args.patch_size,
  17. in_chans=args.img_dim,
  18. patch_embed_dim=384,
  19. depth=12,
  20. num_heads=6,
  21. mlp_ratio=4.0,
  22. act_layer=nn.GELU,
  23. dropout = 0.1)
  24. else:
  25. raise NotImplementedError("Unknown vit: {}".format(args.model))
  26. # Build ViT for classification
  27. return ViTForImageClassification(img_encoder, args.num_classes, qkv_bias=True)