optimzer.py

import torch


def build_optimizer(args, model):
    ## learning rate
    if args.optimizer == "adamw":
        args.base_lr = args.base_lr / args.batch_base * args.batch_size * args.grad_accumulate  # auto scale lr
        ## optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        args.base_lr = args.base_lr / args.batch_base * args.batch_size * args.grad_accumulate  # auto scale lr
        ## optimizer
        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))

    print("=================== Optimizer information ===================")
    print("Optimizer: ", args.optimizer)
    print("- momentum: ", args.momentum)
    print("- weight decay: ", args.weight_decay)
    print("- base lr: ", args.base_lr)
    print("- min lr: ", args.min_lr)
    return optimizer
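

# Minimal usage sketch, assuming `args` is an argparse-style namespace carrying the
# fields read above; the values and the stand-in model below are hypothetical, and a
# real project would populate `args` from argparse or a config file instead.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        optimizer="adamw",    # or "sgd"
        base_lr=1e-3,         # lr defined at the reference batch size
        min_lr=1e-6,          # printed above; typically consumed by an lr scheduler
        batch_base=256,       # reference batch size the base lr was tuned for
        batch_size=64,        # actual per-step batch size
        grad_accumulate=4,    # gradient accumulation steps
        momentum=0.9,         # used by SGD; printed for both branches
        weight_decay=0.05,
    )

    model = torch.nn.Linear(16, 10)           # stand-in for the real model
    optimizer = build_optimizer(args, model)  # effective lr = 1e-3 / 256 * 64 * 4 = 1e-3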