lr_scheduler.py

import torch


# ------------------------- WarmUp LR Scheduler -------------------------
class LinearWarmUpScheduler(object):
    """Linearly ramps the learning rate from base_lr * warmup_factor up to
    base_lr over the first wp_iter iterations."""

    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr):
        # Apply the given learning rate to every parameter group.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def __call__(self, cur_iter, optimizer):
        # Linear warmup: interpolate the factor from warmup_factor to 1.0
        # as cur_iter approaches wp_iter. (Renamed from `iter` to avoid
        # shadowing the builtin.)
        alpha = cur_iter / self.wp_iter
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr)
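
# A minimal usage sketch of LinearWarmUpScheduler (illustrative only): the
# `model`, `optimizer`, and loop below are assumptions, not part of this module.
#
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#   warmup = LinearWarmUpScheduler(base_lr=0.01, wp_iter=500)
#   for cur_iter in range(1000):
#       if cur_iter < warmup.wp_iter:
#           warmup(cur_iter, optimizer)  # ramp lr toward base_lr
#       ...  # forward / backward / optimizer.step()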

## Build WP LR Scheduler
def build_wp_lr_scheduler(cfg):
    print('==============================')
    print('WarmUpScheduler: {}'.format(cfg.warmup))
    print('--base_lr: {}'.format(cfg.base_lr))
    print('--warmup_iters: {}'.format(cfg.warmup_iters))
    print('--warmup_factor: {}'.format(cfg.warmup_factor))

    if cfg.warmup == 'linear':
        wp_lr_scheduler = LinearWarmUpScheduler(
            cfg.base_lr, cfg.warmup_iters, cfg.warmup_factor)
    else:
        # Without this branch, a non-linear warmup type would raise a
        # NameError at the return below.
        raise NotImplementedError('Unknown warmup scheduler: {}'.format(cfg.warmup))

    return wp_lr_scheduler
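
# A hedged example of the cfg object build_wp_lr_scheduler expects; the exact
# config class used by the project is not shown here, so a plain namespace is
# assumed for illustration (0.00066667 is approximately 1/1500):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(warmup='linear', base_lr=0.01,
#                         warmup_iters=500, warmup_factor=1.0 / 1500)
#   wp_scheduler = build_wp_lr_scheduler(cfg)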

# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    if cfg.lr_scheduler == 'step':
        assert 'lr_epoch' in cfg
        print('--lr_epoch: {}'.format(cfg.lr_epoch))
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=cfg.lr_epoch)
    elif cfg.lr_scheduler == 'cosine':
        # TODO: cosine schedule is not implemented yet; fail loudly instead
        # of falling through to an undefined lr_scheduler.
        raise NotImplementedError('cosine lr scheduler is not implemented.')
    else:
        raise NotImplementedError('Unknown lr scheduler: {}'.format(cfg.lr_scheduler))

    if resume is not None:
        print('Load lr scheduler from the checkpoint: ', resume)
        # Load on CPU to avoid device-mismatch errors when resuming.
        checkpoint = torch.load(resume, map_location='cpu')
        # Restore the scheduler state saved in the checkpoint.
        checkpoint_state_dict = checkpoint.pop("lr_scheduler")
        lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
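

if __name__ == '__main__':
    # A minimal end-to-end sketch (assumption: the real project passes its own
    # config object; here a small dict subclass stands in so that both
    # attribute access and the `'lr_epoch' in cfg` check above work).
    class _Cfg(dict):
        __getattr__ = dict.__getitem__

    cfg = _Cfg(lr_scheduler='step', lr_epoch=[90, 120],
               warmup='linear', base_lr=0.01,
               warmup_iters=500, warmup_factor=1.0 / 1500)

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.base_lr)

    wp_scheduler = build_wp_lr_scheduler(cfg)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)

    # Warm up for the first few iterations, then hand control to the
    # epoch-level MultiStepLR schedule.
    for cur_iter in range(5):
        wp_scheduler(cur_iter, optimizer)
        print('iter {}: lr = {:.6f}'.format(
            cur_iter, optimizer.param_groups[0]['lr']))
    lr_scheduler.step()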