optimzer.py

import torch


def build_optimizer(args, model):
    """Build an optimizer from command-line args, scaling the base LR linearly with batch size."""
    print("=================== Optimizer information ===================")
    print("Optimizer: ", args.optimizer)
    ## learning rate
    if args.optimizer == "adamw":
        # Linear LR scaling: ViT models use a reference batch size of 256, others 1024.
        batch_base = 256 if "vit" in args.model else 1024
        args.base_lr = args.base_lr / batch_base * args.batch_size
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.base_lr,
                                      weight_decay=args.weight_decay)
        print('- base lr: ', args.base_lr)
        print('- min lr: ', args.min_lr)
        print('- weight decay: ', args.weight_decay)
    elif args.optimizer == "sgd":
        # Linear LR scaling against a reference batch size of 128.
        batch_base = 128
        args.base_lr = args.base_lr / batch_base * args.batch_size
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.base_lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
        print('- base lr: ', args.base_lr)
        print('- min lr: ', args.min_lr)
        print('- momentum: ', 0.9)
        print('- weight decay: ', args.weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))
    return optimizer
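
A minimal usage sketch, assuming the calling script supplies an argparse namespace with the fields read above (optimizer, model, base_lr, min_lr, batch_size, weight_decay); the model and hyperparameter values here are placeholders, not values from the original repository:

import argparse
import torch.nn as nn

# Hypothetical args; in a real training script these come from argparse flags.
args = argparse.Namespace(optimizer="adamw", model="vit_base", base_lr=1e-3,
                          min_lr=1e-6, batch_size=512, weight_decay=0.05)
model = nn.Linear(768, 1000)  # stand-in for the actual network
optimizer = build_optimizer(args, model)
# base lr is rescaled to 1e-3 / 256 * 512 = 2e-3 before AdamW is constructed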