# lr_scheduler.py

import math
import torch

# ------------------------- WarmUp LR Scheduler -------------------------
## Warmup LR Scheduler
class LinearWarmUpScheduler(object):
    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr, base_lr):
        # scale each param group relative to its initial lr so that groups
        # with different base rates keep their ratio during warmup
        for param_group in optimizer.param_groups:
            init_lr = param_group['initial_lr']
            ratio = init_lr / base_lr
            param_group['lr'] = lr * ratio

    def __call__(self, iter, optimizer):
        # linear warmup: lr ramps from base_lr * warmup_factor at iter=0 to base_lr at iter=wp_iter
        alpha = iter / self.wp_iter
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr, self.base_lr)
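# A minimal usage sketch (not part of the original module): `optimizer` is assumed to be a
# torch optimizer whose param_groups already carry 'initial_lr' (torch schedulers add that
# key on construction, or set it manually as in the __main__ demo at the bottom of this file).
#
#   warmup = LinearWarmUpScheduler(base_lr=0.01, wp_iter=500, warmup_factor=0.00066667)
#   for it in range(warmup.wp_iter):
#       warmup(it, optimizer)   # lr ramps from 0.01 * 0.00066667 toward 0.01
#       optimizer.step()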
## Build WP LR Scheduler
def build_wp_lr_scheduler(cfg, base_lr=0.01):
    print('==============================')
    print('WarmUpScheduler: {}'.format(cfg['warmup']))
    print('--base_lr: {}'.format(base_lr))
    print('--warmup_iters: {}'.format(cfg['warmup_iters']))
    print('--warmup_factor: {}'.format(cfg['warmup_factor']))

    if cfg['warmup'] == 'linear':
        wp_lr_scheduler = LinearWarmUpScheduler(base_lr, cfg['warmup_iters'], cfg['warmup_factor'])
    else:
        # only linear warmup is implemented; fail loudly instead of returning an unbound name
        raise NotImplementedError('unknown warmup scheduler: {}'.format(cfg['warmup']))

    return wp_lr_scheduler
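# Expected cfg keys (taken from the accesses above); the values here are only an example:
#   cfg = {'warmup': 'linear', 'warmup_iters': 500, 'warmup_factor': 0.00066667}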
# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg['lr_scheduler']))

    if cfg['lr_scheduler'] == 'step':
        assert 'lr_epoch' in cfg
        print('--lr_epoch: {}'.format(cfg['lr_epoch']))
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg['lr_epoch'])
    elif cfg['lr_scheduler'] == 'cosine':
        # cosine decay is not built here; see build_lambda_lr_scheduler below
        raise NotImplementedError('cosine is handled by build_lambda_lr_scheduler')
    else:
        raise NotImplementedError('unknown lr scheduler: {}'.format(cfg['lr_scheduler']))

    if resume is not None and resume.lower() != "none":
        print('keep training: ', resume)
        checkpoint = torch.load(resume)
        # restore the scheduler state saved in the checkpoint
        checkpoint_state_dict = checkpoint.pop("lr_scheduler")
        lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
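# Expected cfg keys (taken from the accesses above); the milestone epochs are just an example:
#   cfg = {'lr_scheduler': 'step', 'lr_epoch': [90, 120]}
# `resume`, if given, is a path to a checkpoint dict that stores an 'lr_scheduler' state dict.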
def build_lambda_lr_scheduler(cfg, optimizer, epochs):
    """Build learning rate scheduler from cfg file."""
    print('==============================')
    print('Lr Scheduler: {}'.format(cfg['scheduler']))

    # Cosine LR scheduler: lr factor decays from 1.0 to cfg['lrf'] over `epochs`
    if cfg['scheduler'] == 'cosine':
        lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg['lrf'] - 1) + 1
    # Linear LR scheduler: lr factor decays linearly from 1.0 to cfg['lrf']
    elif cfg['scheduler'] == 'linear':
        lf = lambda x: (1 - x / epochs) * (1.0 - cfg['lrf']) + cfg['lrf']
    else:
        raise NotImplementedError('unknown lr scheduler: {}'.format(cfg['scheduler']))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    return scheduler, lf
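

if __name__ == '__main__':
    # Self-contained sketch (not part of the original module): wires the warmup and
    # lambda schedulers together with a dummy model. The cfg values, epoch count and
    # the dummy Linear model are assumptions made only for this demo.
    model = torch.nn.Linear(10, 2)
    base_lr = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)
    # LinearWarmUpScheduler.set_lr reads param_group['initial_lr']; torch schedulers
    # add that key on construction, but record it explicitly to be safe.
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

    wp_cfg = {'warmup': 'linear', 'warmup_iters': 500, 'warmup_factor': 0.00066667}
    wp_scheduler = build_wp_lr_scheduler(wp_cfg, base_lr)

    epochs = 100
    lr_cfg = {'scheduler': 'cosine', 'lrf': 0.01}
    scheduler, lf = build_lambda_lr_scheduler(lr_cfg, optimizer, epochs)

    # iteration-level warmup first, then epoch-level cosine decay
    for it in range(wp_cfg['warmup_iters']):
        wp_scheduler(it, optimizer)
    for epoch in range(epochs):
        optimizer.step()
        scheduler.step()
    print('final lr: {:.6f}'.format(optimizer.param_groups[0]['lr']))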