lr_scheduler.py

import torch


# ------------------------- WarmUp LR Scheduler -------------------------
## Warmup LR Scheduler
class LinearWarmUpScheduler(object):
    """Linearly ramp the learning rate from base_lr * warmup_factor up to
    base_lr over the first wp_iter iterations."""
    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr):
        # Scale every param group relative to its initial LR so that groups
        # with different initial LRs keep their ratios during warmup.
        for param_group in optimizer.param_groups:
            init_lr = param_group['initial_lr']
            ratio = init_lr / self.base_lr
            param_group['lr'] = lr * ratio

    def __call__(self, iter, optimizer):
        # Linear warmup: interpolate from warmup_factor to 1.0 over wp_iter steps.
        alpha = iter / self.wp_iter
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr)
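

# Note: LinearWarmUpScheduler.set_lr reads param_group['initial_lr'], which
# PyTorch adds to each param group when a torch.optim.lr_scheduler is
# constructed for that optimizer (build_lr_scheduler below does this). If the
# warmup scheduler is used without one, set the key manually first, e.g.:
#
#     for group in optimizer.param_groups:
#         group.setdefault('initial_lr', group['lr'])
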

## Build WP LR Scheduler
def build_wp_lr_scheduler(cfg):
    print('==============================')
    print('WarmUpScheduler: {}'.format(cfg.warmup))
    print('--base_lr: {}'.format(cfg.base_lr))
    print('--warmup_iters: {} ({})'.format(cfg.warmup_iters, cfg.warmup_iters * cfg.grad_accumulate))
    print('--warmup_factor: {}'.format(cfg.warmup_factor))

    if cfg.warmup == 'linear':
        wp_lr_scheduler = LinearWarmUpScheduler(cfg.base_lr, cfg.warmup_iters, cfg.warmup_factor)
    else:
        # Fail loudly instead of returning an undefined variable below.
        raise NotImplementedError('Unknown warmup scheduler: {}'.format(cfg.warmup))

    return wp_lr_scheduler


# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    if cfg.lr_scheduler == 'step':
        assert hasattr(cfg, 'lr_epoch')
        print('--lr_epoch: {}'.format(cfg.lr_epoch))
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg.lr_epoch)
    elif cfg.lr_scheduler == 'cosine':
        # Not implemented yet: fail loudly instead of falling through and
        # returning an undefined scheduler.
        raise NotImplementedError('Cosine LR scheduler is not implemented yet.')
    else:
        raise NotImplementedError('Unknown LR scheduler: {}'.format(cfg.lr_scheduler))

    if resume is not None and resume.lower() != "none":
        print('Load lr scheduler from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        # checkpoint state dict
        checkpoint_state_dict = checkpoint.pop("lr_scheduler")
        lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
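

if __name__ == '__main__':
    # Minimal usage sketch of the two builders above. The config fields and
    # values below are illustrative placeholders, not values from any real
    # training recipe.
    from types import SimpleNamespace
    import torch.nn as nn

    cfg = SimpleNamespace(
        warmup='linear', warmup_iters=10, warmup_factor=0.00066667,
        base_lr=0.01, grad_accumulate=1,
        lr_scheduler='step', lr_epoch=[2, 3],
    )

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.base_lr)

    # Build the epoch-level scheduler first: constructing it stores 'initial_lr'
    # in every param group, which LinearWarmUpScheduler.set_lr relies on.
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    wp_lr_scheduler = build_wp_lr_scheduler(cfg)

    # Warmup phase: overwrite the learning rate for the first warmup_iters steps.
    for iteration in range(cfg.warmup_iters):
        wp_lr_scheduler(iteration, optimizer)
    print('lr after warmup:', optimizer.param_groups[0]['lr'])

    # After warmup, step the epoch-level scheduler once per epoch as usual.
    for epoch in range(4):
        lr_scheduler.step()
        print('epoch {}: lr = {:.6f}'.format(epoch, optimizer.param_groups[0]['lr']))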