lr_scheduler.py

import numpy as np
import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR


# ------------------------- Warmup LR Scheduler -------------------------
class LinearWarmUpLrScheduler(object):
    def __init__(self, wp_iter=500, base_lr=0.01, warmup_bias_lr=0.0):
        self.wp_iter = wp_iter
        self.base_lr = base_lr
        self.warmup_bias_lr = warmup_bias_lr

    def set_lr(self, optimizer, cur_lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

    def __call__(self, iter, optimizer):
        # Linear warmup: interpolate each group's lr over [0, wp_iter].
        xi = [0, self.wp_iter]
        for j, x in enumerate(optimizer.param_groups):
            # The first group (index 0, typically biases) starts from warmup_bias_lr;
            # all other groups rise from 0.0 to their initial_lr.
            x['lr'] = np.interp(
                iter, xi, [self.warmup_bias_lr if j == 0 else 0.0, x['initial_lr']])
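
# Illustration of the warmup interpolation above (example numbers only,
# assuming wp_iter=500, warmup_bias_lr=0.0 and a group initial_lr of 0.01):
#   iter = 0    -> lr = 0.0
#   iter = 250  -> lr = 0.005
#   iter >= 500 -> lr = 0.01  (np.interp clamps outside [0, wp_iter])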
# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    if cfg.lr_scheduler == "step":
        # Decay the lr by 10x at 1/2 and 3/4 of the total training epochs.
        lr_step = [cfg.max_epoch // 2, cfg.max_epoch * 3 // 4]
        lr_scheduler = MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
    elif cfg.lr_scheduler == "cosine":
        # Exclude the warmup epochs from the cosine annealing period.
        if hasattr(cfg, "warmup_epoch"):
            total_epochs = cfg.max_epoch - cfg.warmup_epoch - 1
        else:
            total_epochs = cfg.max_epoch - 1
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs, eta_min=cfg.min_lr)
    else:
        raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))

    # Optionally restore the scheduler state from a checkpoint.
    if resume is not None and resume.lower() != "none":
        checkpoint = torch.load(resume)
        if 'lr_scheduler' in checkpoint:
            print('--Load lr scheduler from the checkpoint: ', resume)
            checkpoint_state_dict = checkpoint.pop("lr_scheduler")
            lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
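

# --------------------------- Usage sketch ---------------------------
# Minimal, illustrative example (not part of the original module): the
# `cfg` fields, the model, and the iteration counts below are assumptions
# chosen only to show how the two schedulers are typically combined.
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch.nn as nn

    cfg = SimpleNamespace(lr_scheduler="cosine", max_epoch=10,
                          warmup_epoch=1, min_lr=1e-5)
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Build the epoch-wise scheduler first; the PyTorch scheduler stores
    # 'initial_lr' in each param group, which the warmup scheduler reads.
    lr_scheduler = build_lr_scheduler(cfg, optimizer, resume=None)
    warmup = LinearWarmUpLrScheduler(wp_iter=500, base_lr=0.01)

    iters_per_epoch = 200  # assumed value for the demo
    for epoch in range(cfg.max_epoch):
        for i in range(iters_per_epoch):
            cur_iter = epoch * iters_per_epoch + i
            if cur_iter < warmup.wp_iter:
                warmup(cur_iter, optimizer)  # per-iteration linear warmup
            # ... forward / backward / optimizer.step() would go here ...
        lr_scheduler.step()  # per-epoch cosine (or step) schedule
        print('epoch {}: lr = {:.6f}'.format(epoch, optimizer.param_groups[0]['lr']))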