# lr_scheduler.py

import torch


# ------------------------- WarmUp LR Scheduler -------------------------
class LinearWarmUpScheduler(object):
    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr):
        # Scale each param group relative to its own initial lr so groups with
        # different base rates keep their ratio during warmup.
        for param_group in optimizer.param_groups:
            init_lr = param_group['initial_lr']
            ratio = init_lr / self.base_lr
            param_group['lr'] = lr * ratio

    def __call__(self, cur_iter, optimizer):
        # warmup: linearly interpolate the factor from warmup_factor to 1.0
        alpha = cur_iter / self.wp_iter
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr)


## Build WP LR Scheduler
def build_wp_lr_scheduler(cfg, base_lr=0.01):
    print('==============================')
    print('WarmUpScheduler: {}'.format(cfg['warmup']))
    print('--base_lr: {}'.format(base_lr))
    print('--warmup_iters: {}'.format(cfg['warmup_iters']))
    print('--warmup_factor: {}'.format(cfg['warmup_factor']))

    if cfg['warmup'] == 'linear':
        wp_lr_scheduler = LinearWarmUpScheduler(base_lr, cfg['warmup_iters'], cfg['warmup_factor'])
    else:
        # Fail loudly instead of returning an undefined name.
        raise NotImplementedError('Unknown warmup scheduler: {}'.format(cfg['warmup']))

    return wp_lr_scheduler


# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg['lr_scheduler']))

    if cfg['lr_scheduler'] == 'step':
        assert 'lr_epoch' in cfg
        print('--lr_epoch: {}'.format(cfg['lr_epoch']))
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=cfg['lr_epoch'])
    elif cfg['lr_scheduler'] == 'cosine':
        # Cosine schedule is not implemented yet; raise instead of silently
        # falling through with an undefined lr_scheduler.
        raise NotImplementedError('Cosine lr scheduler is not supported yet.')
    else:
        raise NotImplementedError('Unknown lr scheduler: {}'.format(cfg['lr_scheduler']))

    if resume is not None:
        print('Load lr scheduler from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        # restore the scheduler state saved in the checkpoint
        checkpoint_state_dict = checkpoint.pop("lr_scheduler")
        lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
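

# -------------------------- Usage sketch --------------------------
# A minimal, self-contained sketch of how these schedulers could be wired into a
# training loop. The cfg dict, toy model, iters_per_epoch, and epoch count below
# are illustrative assumptions, not part of this module's API.
if __name__ == '__main__':
    cfg = {
        'warmup': 'linear',
        'warmup_iters': 500,
        'warmup_factor': 0.00066667,
        'lr_scheduler': 'step',
        'lr_epoch': [8, 11],        # epochs at which MultiStepLR decays the lr
    }
    base_lr = 0.01

    model = torch.nn.Linear(10, 2)  # toy model standing in for a real network
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)
    # LinearWarmUpScheduler.set_lr reads 'initial_lr' from each param group.
    for param_group in optimizer.param_groups:
        param_group.setdefault('initial_lr', param_group['lr'])

    wp_lr_scheduler = build_wp_lr_scheduler(cfg, base_lr)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)

    iters_per_epoch = 100           # assumed number of iterations per epoch
    for epoch in range(12):
        for it in range(iters_per_epoch):
            cur_iter = epoch * iters_per_epoch + it
            if cur_iter <= cfg['warmup_iters']:
                # ramp the lr up to base_lr over the first warmup_iters iterations
                wp_lr_scheduler(cur_iter, optimizer)
            # ... forward / backward / optimizer.step() would go here ...
        # decay the lr at the configured epochs
        lr_scheduler.step()
        print('epoch {}: lr = {:.6f}'.format(epoch, optimizer.param_groups[0]['lr']))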