build.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import os
  2. import torch
  3. from .vit import build_vit
  4. from .vit_mae import build_vit_mae
  5. from .vit_cls import ViTForImageClassification
  6. def build_vision_transformer(args, model_type='default'):
  7. assert args.model in ['vit_t', 'vit_s', 'vit_b', 'vit_l', 'vit_h'], "Unknown vit model: {}".format(args.model)
  8. # ----------- Masked Image Modeling task -----------
  9. if model_type == 'mae':
  10. model = build_vit_mae(args.model, args.img_size, args.patch_size, args.img_dim, args.mask_ratio)
  11. # ----------- Image Classification task -----------
  12. elif model_type == 'cls':
  13. image_encoder = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
  14. model = ViTForImageClassification(image_encoder, num_classes=args.num_classes, qkv_bias=True)
  15. load_mae_pretrained(model.encoder, args.pretrained)
  16. # ----------- Vison Backbone -----------
  17. elif model_type == 'default':
  18. model = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
  19. load_mae_pretrained(model, args.pretrained)
  20. else:
  21. raise NotImplementedError("Unknown model type: {}".format(model_type))
  22. return model
  23. def load_mae_pretrained(model, ckpt=None):
  24. if ckpt is not None:
  25. # check path
  26. if not os.path.exists(ckpt):
  27. print("No pretrained model.")
  28. return model
  29. print('- Loading pretrained from: {}'.format(ckpt))
  30. checkpoint = torch.load(ckpt, map_location='cpu')
  31. # checkpoint state dict
  32. encoder_state_dict = checkpoint.pop("encoder")
  33. # load encoder weight into ViT's encoder
  34. model.load_state_dict(encoder_state_dict)