```python
import torch


def build_optimizer(args, model):
    # Learning rate: scale the base lr linearly with the batch size
    # (linear scaling rule, referenced to a batch size of 256).
    if args.optimizer == "adamw":
        args.base_lr = args.base_lr / 256 * args.batch_size
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.base_lr,
                                      weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        args.base_lr = args.base_lr / 256 * args.batch_size
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.base_lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))

    print("=================== Optimizer information ===================")
    print("Optimizer: ", args.optimizer)
    print("- base lr: ", args.base_lr)
    print("- min lr: ", args.min_lr)
    return optimizer
```
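For reference, below is a minimal sketch of how this builder might be invoked. The argument field names (`optimizer`, `base_lr`, `min_lr`, `batch_size`, `weight_decay`) mirror the attributes that `build_optimizer` reads; the `argparse` setup, the default values, and the toy model are illustrative assumptions, not part of the original code.

```python
import argparse

import torch.nn as nn

# Hypothetical argument setup; defaults are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--optimizer", type=str, default="adamw")
parser.add_argument("--base_lr", type=float, default=1e-3)   # lr defined at a reference batch size of 256
parser.add_argument("--min_lr", type=float, default=1e-6)    # only printed here; typically consumed by an lr scheduler
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--weight_decay", type=float, default=0.05)
args = parser.parse_args([])  # empty list: use the defaults instead of CLI input

model = nn.Linear(10, 2)  # toy model standing in for the real network
optimizer = build_optimizer(args, model)
```

With these assumed defaults, the `base_lr / 256 * batch_size` step yields an effective learning rate of `1e-3 / 256 * 128 = 5e-4`, i.e. half the reference rate because the batch size is half of 256.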