| 123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- import os
- import torch
- from .vit import build_vit
- from .vit_mae import build_vit_mae
- from .vit_cls import ViTForImageClassification
def build_vision_transformer(args, model_type='default'):
    """Build a Vision Transformer for the requested task.

    Args:
        args: namespace providing model / img_size / patch_size / img_dim,
            plus mask_ratio (mae), num_classes (cls) and pretrained (cls, default).
        model_type: one of 'mae', 'cls', 'default'.

    Returns:
        The constructed model for the chosen task.

    Raises:
        NotImplementedError: if model_type is not recognized.
    """
    assert args.model in ['vit_t', 'vit_s', 'vit_b', 'vit_l', 'vit_h'], "Unknown vit model: {}".format(args.model)

    # ----------- Masked Image Modeling task -----------
    if model_type == 'mae':
        return build_vit_mae(args.model, args.img_size, args.patch_size, args.img_dim, args.mask_ratio)

    # ----------- Image Classification task -----------
    if model_type == 'cls':
        encoder = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
        classifier = ViTForImageClassification(encoder, num_classes=args.num_classes, qkv_bias=True)
        # Warm-start only the encoder from an MAE checkpoint; the head stays random.
        load_mae_pretrained(classifier.encoder, args.pretrained)
        return classifier

    # ----------- Vision Backbone -----------
    if model_type == 'default':
        backbone = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
        load_mae_pretrained(backbone, args.pretrained)
        return backbone

    raise NotImplementedError("Unknown model type: {}".format(model_type))
- def load_mae_pretrained(model, ckpt=None):
- if ckpt is not None:
- # check path
- if not os.path.exists(ckpt):
- print("No pretrained model.")
- return model
- print('- Loading pretrained from: {}'.format(ckpt))
- checkpoint = torch.load(ckpt, map_location='cpu')
- # checkpoint state dict
- encoder_state_dict = checkpoint.pop("encoder")
- # load encoder weight into ViT's encoder
- model.load_state_dict(encoder_state_dict)
|