warmup_schedule.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Build warmup scheduler
  2. def build_warmup(cfg, base_lr=0.01, wp_iter=500):
  3. print('==============================')
  4. print('WarmUpScheduler: {}'.format(cfg['warmup']))
  5. print('--base_lr: {}'.format(base_lr))
  6. print('--warmup_factor: {}'.format(cfg['warmup_factor']))
  7. print('--wp_iter: {}'.format(wp_iter))
  8. warmup_scheduler = WarmUpScheduler(name=cfg['warmup'],
  9. base_lr=base_lr,
  10. wp_iter=wp_iter,
  11. warmup_factor=cfg['warmup_factor'])
  12. return warmup_scheduler
  13. # Basic Warmup Scheduler
  14. class WarmUpScheduler(object):
  15. def __init__(self,
  16. name='linear',
  17. base_lr=0.01,
  18. wp_iter=500,
  19. warmup_factor=0.00066667):
  20. self.name = name
  21. self.base_lr = base_lr
  22. self.wp_iter = wp_iter
  23. self.warmup_factor = warmup_factor
  24. def set_lr(self, optimizer, lr):
  25. for param_group in optimizer.param_groups:
  26. param_group['lr'] = lr
  27. def warmup(self, iter, optimizer):
  28. # warmup
  29. assert iter < self.wp_iter
  30. if self.name == 'exp':
  31. tmp_lr = self.base_lr * pow(iter / self.wp_iter, 4)
  32. self.set_lr(optimizer, tmp_lr)
  33. elif self.name == 'linear':
  34. alpha = iter / self.wp_iter
  35. warmup_factor = self.warmup_factor * (1 - alpha) + alpha
  36. tmp_lr = self.base_lr * warmup_factor
  37. self.set_lr(optimizer, tmp_lr)
  38. def __call__(self, iter, optimizer):
  39. self.warmup(iter, optimizer)