```python
import torch


def build_optimizer(args, model):
    ## learning rate: linearly scale the base LR with the effective batch size
    args.base_lr = args.base_lr / args.batch_base * args.batch_size * args.grad_accumulate

    ## optimizer
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay
        )
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.base_lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))

    print("=================== Optimizer information ===================")
    print("Optimizer: ", args.optimizer)
    print("- momentum: ", args.momentum)        # used by SGD only
    print("- weight decay: ", args.weight_decay)
    print("- base lr: ", args.base_lr)
    print("- min lr: ", args.min_lr)            # consumed by an LR scheduler, not here
    return optimizer
```
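For context, here is a minimal sketch of how this helper might be invoked. The field names on the `Namespace` simply mirror what `build_optimizer` reads; the toy model and the concrete values are purely illustrative, not part of the original code.

```python
import argparse
import torch.nn as nn

# Hypothetical arguments; every field below is one that build_optimizer accesses.
args = argparse.Namespace(
    optimizer="adamw",
    base_lr=1e-3,        # LR defined for a reference batch size of `batch_base`
    min_lr=1e-6,         # only printed here; a scheduler would consume it
    batch_base=64,       # reference batch size for the linear scaling rule
    batch_size=256,
    grad_accumulate=2,
    momentum=0.9,        # used by SGD only
    weight_decay=0.05,
)

model = nn.Linear(10, 10)  # toy model, stand-in for the real network
optimizer = build_optimizer(args, model)
# Effective LR: 1e-3 / 64 * 256 * 2 = 8e-3 (linear scaling with total batch size)
```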