# lr_scheduler.py

import torch


# Basic linear warmup scheduler: ramps the learning rate from
# base_lr * warmup_factor up to base_lr over the first wp_iter iterations.
class LinearWarmUpLrScheduler(object):
    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, cur_lr):
        # Scale each param group relative to its own 'initial_lr' so that
        # groups with different base learning rates keep their ratio to base_lr.
        for param_group in optimizer.param_groups:
            init_lr = param_group['initial_lr']
            ratio = init_lr / self.base_lr
            param_group['lr'] = cur_lr * ratio

    def __call__(self, iter, optimizer):
        # Linear warmup: interpolate from warmup_factor to 1.0 as iter -> wp_iter.
        assert iter < self.wp_iter
        alpha = iter / self.wp_iter
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr)
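

# ----------------------------------------------------------------------
# Standalone usage sketch (an assumption, not part of the original file;
# the name _warmup_demo and the dummy model are made up for illustration).
# The warmup scheduler reads 'initial_lr' from every param group; PyTorch
# schedulers record that key on construction, so when the warmup scheduler
# is used on its own the key has to be filled in by hand, as shown here.
# ----------------------------------------------------------------------
def _warmup_demo():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

    warmup_scheduler = LinearWarmUpLrScheduler(base_lr=0.01, wp_iter=500)
    for cur_iter in range(500):
        warmup_scheduler(cur_iter, optimizer)
        # ... one training step would go here ...
    # By the last warmup iteration the lr has ramped up to roughly base_lr.
    print("lr after warmup: {:.6f}".format(optimizer.param_groups[0]['lr']))

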
def build_lr_scheduler(args, optimizer):
    # Build the main (post-warmup) epoch-wise scheduler from command-line args.
    if args.lr_scheduler == "step":
        # Decay the learning rate by 10x at 1/3 and 2/3 of the total epochs.
        lr_step = [args.max_epoch // 3, args.max_epoch // 3 * 2]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_step, gamma=0.1)
    elif args.lr_scheduler == "cosine":
        # Cosine annealing over the epochs that remain after warmup.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.max_epoch - args.wp_epoch - 1, eta_min=args.min_lr)
    else:
        raise NotImplementedError("Unknown lr scheduler: {}".format(args.lr_scheduler))

    print("=================== LR Scheduler information ===================")
    print("LR Scheduler: ", args.lr_scheduler)

    return scheduler
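

# ----------------------------------------------------------------------
# Minimal end-to-end sketch (an assumption, not part of the original file):
# it wires LinearWarmUpLrScheduler and build_lr_scheduler together with a
# dummy model. The Namespace fields mirror the ones read above; the values,
# iters_per_epoch, and the wp_iter = wp_epoch * iters_per_epoch convention
# are illustrative assumptions, not taken from the original training script.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    model = torch.nn.Linear(8, 2)
    base_lr = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

    args = argparse.Namespace(
        lr_scheduler="cosine",  # or "step"
        max_epoch=12,
        wp_epoch=1,
        min_lr=1e-6,
    )

    iters_per_epoch = 100  # assumed; depends on dataset size and batch size
    wp_iter = args.wp_epoch * iters_per_epoch

    # Constructing the PyTorch scheduler also records 'initial_lr' in each
    # param group, which the warmup scheduler relies on.
    lr_scheduler = build_lr_scheduler(args, optimizer)
    warmup_scheduler = LinearWarmUpLrScheduler(base_lr=base_lr, wp_iter=wp_iter)

    for epoch in range(args.max_epoch):
        for it in range(iters_per_epoch):
            cur_iter = epoch * iters_per_epoch + it
            if cur_iter < wp_iter:
                # Warmup phase: set the lr directly for this iteration.
                warmup_scheduler(cur_iter, optimizer)
            # ... forward / backward / optimizer.step() would go here ...
        if epoch >= args.wp_epoch:
            # Past warmup: decay once per epoch with the main scheduler.
            lr_scheduler.step()
        print("epoch {:2d}: lr = {:.6f}".format(epoch, optimizer.param_groups[0]['lr']))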