optimizer.py

import torch


def build_optimizer(args, model):
    # Linear learning-rate scaling: args.base_lr is interpreted as the rate
    # for a reference batch size of 256 and rescaled to the actual batch size.
    if args.optimizer == "adamw":
        args.base_lr = args.base_lr / 256 * args.batch_size
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.base_lr,
                                      weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        args.base_lr = args.base_lr / 256 * args.batch_size
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.base_lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))

    print("=================== Optimizer information ===================")
    print("Optimizer: ", args.optimizer)
    print("- base lr: ", args.base_lr)
    print("- min lr: ", args.min_lr)
    return optimizer
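
# A minimal usage sketch (an assumption, not part of the original file): it
# assumes `args` comes from argparse with the fields build_optimizer reads
# (optimizer, base_lr, batch_size, weight_decay, min_lr). Names and defaults
# below are illustrative.
if __name__ == "__main__":
    import argparse

    import torch.nn as nn

    parser = argparse.ArgumentParser()
    parser.add_argument("--optimizer", default="adamw", choices=["adamw", "sgd"])
    parser.add_argument("--base_lr", type=float, default=1e-3)      # lr at the reference batch size of 256
    parser.add_argument("--min_lr", type=float, default=1e-6)      # printed here; presumably consumed by an LR scheduler elsewhere
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    args = parser.parse_args()

    model = nn.Linear(10, 2)  # placeholder model for demonstration
    optimizer = build_optimizer(args, model)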