lr_scheduler.py

import math

import numpy as np
import torch


# ------------------------- WarmUp LR Scheduler -------------------------
class LinearWarmUpLrScheduler(object):
    """Linearly warm up the learning rate of each param group over the first wp_iter iterations."""
    def __init__(self, wp_iter=500, base_lr=0.01, warmup_bias_lr=0.1, warmup_momentum=0.8):
        self.wp_iter = wp_iter
        self.warmup_momentum = warmup_momentum
        self.base_lr = base_lr
        self.warmup_bias_lr = warmup_bias_lr

    def set_lr(self, optimizer, cur_lr):
        # Apply the same learning rate to every param group.
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

    def __call__(self, iter, optimizer):
        # Warmup: interpolate each group's lr over [0, wp_iter].
        xi = [0, self.wp_iter]
        for j, x in enumerate(optimizer.param_groups):
            # Bias lr falls from warmup_bias_lr to initial_lr; all other lrs rise from 0.0 to initial_lr.
            x['lr'] = np.interp(
                iter, xi, [self.warmup_bias_lr if j == 0 else 0.0, x['initial_lr']])
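
# A minimal, hypothetical usage sketch (not part of the original file): it assumes each
# optimizer param group carries an 'initial_lr' entry, which __call__ reads, and that the
# first param group is the "bias" group. The dummy model and SGD settings are illustrative.
def _demo_warmup_scheduler():
    """Drive LinearWarmUpLrScheduler over the warmup iterations of a dummy optimizer."""
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    for g in optimizer.param_groups:
        g['initial_lr'] = g['lr']
    warmup = LinearWarmUpLrScheduler(wp_iter=500, base_lr=0.01)
    for it in range(warmup.wp_iter):
        # With a single param group (index 0), the lr falls from warmup_bias_lr to initial_lr.
        warmup(it, optimizer)
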
# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    if cfg.lr_scheduler == "step":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=cfg.lr_step, gamma=0.1)
    elif cfg.lr_scheduler == "cosine":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.max_epoch - cfg.warmup_epoch - 1, eta_min=cfg.min_lr)
    else:
        raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))

    if resume is not None and resume.lower() != "none":
        checkpoint = torch.load(resume)
        if 'lr_scheduler' in checkpoint.keys():
            print('--Load lr scheduler from the checkpoint: ', resume)
            # Restore the scheduler state saved in the checkpoint.
            checkpoint_state_dict = checkpoint.pop("lr_scheduler")
            lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
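
# A minimal, hypothetical usage sketch (not part of the original file): the cfg fields below
# (lr_scheduler, lr_step) are the ones the "step" branch of build_lr_scheduler reads; the
# SimpleNamespace config and dummy model are illustrative assumptions.
def _demo_build_lr_scheduler():
    """Build a MultiStepLR scheduler and step it once per (dummy) epoch."""
    from types import SimpleNamespace
    cfg = SimpleNamespace(lr_scheduler="step", lr_step=[60, 80])
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    lr_scheduler = build_lr_scheduler(cfg, optimizer, resume=None)
    for epoch in range(90):
        optimizer.step()      # placeholder for one epoch of training
        lr_scheduler.step()   # lr is multiplied by 0.1 at epochs 60 and 80
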
def build_lambda_lr_scheduler(cfg, optimizer, epochs):
    """Build a LambdaLR learning rate scheduler from the cfg file."""
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    # Cosine LR schedule: decays from 1.0 to min_lr_ratio along a half cosine.
    if cfg.lr_scheduler == 'cosine':
        lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.min_lr_ratio - 1) + 1
    # Linear LR schedule: decays from 1.0 to min_lr_ratio linearly.
    elif cfg.lr_scheduler == 'linear':
        lf = lambda x: (1 - x / epochs) * (1.0 - cfg.min_lr_ratio) + cfg.min_lr_ratio
    else:
        raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    return scheduler, lf
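
# A minimal, hypothetical usage sketch (not part of the original file): the cfg fields
# (lr_scheduler, min_lr_ratio) are the ones the cosine/linear lambdas read; the config
# object, dummy model, and epoch count are illustrative assumptions.
def _demo_build_lambda_lr_scheduler():
    """Build the cosine LambdaLR schedule and step it once per (dummy) epoch."""
    from types import SimpleNamespace
    cfg = SimpleNamespace(lr_scheduler="cosine", min_lr_ratio=0.05)
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler, lf = build_lambda_lr_scheduler(cfg, optimizer, epochs=100)
    for epoch in range(100):
        optimizer.step()    # placeholder for one epoch of training
        scheduler.step()    # lr follows base_lr * lf(epoch), ending near base_lr * min_lr_ratio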