lr_scheduler.py

import math
import torch


# ------------------------- Warmup LR Scheduler -------------------------
class LinearWarmUpLrScheduler(object):
    """Linearly ramps the learning rate from base_lr * warmup_factor up to
    base_lr over the first wp_iter iterations."""

    def __init__(self, base_lr=0.01, wp_iter=500):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        # starting fraction of base_lr at iteration 0 (~1/1500)
        self.warmup_factor = 0.00066667

    def set_lr(self, optimizer, cur_lr):
        # Scale each param group relative to its own 'initial_lr' so that
        # groups with different base rates keep their ratio to base_lr.
        for param_group in optimizer.param_groups:
            init_lr = param_group['initial_lr']
            ratio = init_lr / self.base_lr
            param_group['lr'] = cur_lr * ratio

    def __call__(self, iter, optimizer):
        # Warmup phase only: callers must stop using this scheduler once
        # iter reaches wp_iter.
        assert iter < self.wp_iter
        alpha = iter / self.wp_iter
        # interpolate linearly from warmup_factor (iter=0) to 1.0 (iter=wp_iter)
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr)
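
# Usage sketch (illustration, not from the original file): with base_lr=0.01
# and wp_iter=500, iteration 0 gives lr = 0.01 * 0.00066667 ≈ 6.7e-6 and
# iteration 250 gives lr = 0.01 * (0.00066667 * 0.5 + 0.5) ≈ 0.005, so the
# learning rate ramps linearly from roughly base_lr / 1500 up to base_lr.
#
#   warmup = LinearWarmUpLrScheduler(base_lr=0.01, wp_iter=500)
#   for it in range(warmup.wp_iter):      # per-iteration warmup phase
#       warmup(it, optimizer)             # assumes 'initial_lr' is present in
#                                         # each param_group (PyTorch LR
#                                         # schedulers set it on construction)
#       ...                               # forward / backward / optimizer.step()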

# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    if cfg.lr_scheduler == "step":
        # Decay the lr by 10x at 1/3 and 2/3 of the total training epochs.
        lr_step = [cfg.max_epoch // 3, cfg.max_epoch // 3 * 2]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_step, gamma=0.1)
    elif cfg.lr_scheduler == "cosine":
        # Cosine decay over the post-warmup epochs down to cfg.min_lr.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.max_epoch - cfg.warmup_epoch - 1, eta_min=cfg.min_lr)
    else:
        raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))

    # Optionally restore the scheduler state from a checkpoint.
    if resume is not None and resume.lower() != "none":
        checkpoint = torch.load(resume)
        if 'lr_scheduler' in checkpoint:
            print('--Load lr scheduler from the checkpoint: ', resume)
            checkpoint_state_dict = checkpoint.pop("lr_scheduler")
            lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
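

if __name__ == "__main__":
    # Usage sketch (illustration only, not part of the original file). The
    # config object below is a hypothetical stand-in that only mirrors the
    # attributes read by build_lr_scheduler (lr_scheduler, max_epoch,
    # warmup_epoch, min_lr); real configs in the project may differ.
    from types import SimpleNamespace

    cfg = SimpleNamespace(lr_scheduler="cosine", max_epoch=100,
                          warmup_epoch=1, min_lr=0.0)
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Building the epoch-level scheduler first also records 'initial_lr'
    # in each param_group, which LinearWarmUpLrScheduler relies on.
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    warmup = LinearWarmUpLrScheduler(base_lr=0.01, wp_iter=500)

    # Typical pattern: per-iteration warmup first, then one scheduler step
    # per epoch once warmup has finished.
    for it in range(10):
        warmup(it, optimizer)
    lr_scheduler.step()
    print("lr after warmup + one scheduler step:",
          optimizer.param_groups[0]["lr"])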