lr_scheduler.py

import math
import torch


def build_lr_scheduler(cfg, optimizer, epochs):
    """Build a learning rate scheduler from the config dict."""
    print('==============================')
    print('Lr Scheduler: {}'.format(cfg['scheduler']))
    if cfg['scheduler'] == 'cosine':
        # Cosine annealing: the multiplicative LR factor decays from 1 to
        # cfg['lrf'] over `epochs` following half a cosine wave.
        lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg['lrf'] - 1) + 1
    elif cfg['scheduler'] == 'linear':
        # Linear decay: the LR factor is interpolated from 1 to cfg['lrf'].
        lf = lambda x: (1 - x / epochs) * (1.0 - cfg['lrf']) + cfg['lrf']
    else:
        raise ValueError('Unknown lr scheduler: {}'.format(cfg['scheduler']))
    # LambdaLR multiplies each parameter group's initial LR by lf(epoch).
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    return scheduler, lf
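
A rough usage sketch follows; the model, optimizer, cfg values, and epoch loop here are illustrative assumptions, not part of this file. It shows the factory returning a LambdaLR scheduler that is stepped once per epoch:

# Hypothetical usage sketch -- the model, cfg contents, and loop below
# are assumptions for illustration, not part of lr_scheduler.py.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
cfg = {'scheduler': 'cosine', 'lrf': 0.01}  # final LR = 0.01 * 0.01
epochs = 100

scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)
for epoch in range(epochs):
    # ... forward/backward passes for this epoch ...
    optimizer.step()
    scheduler.step()  # advance the LR schedule once per epoch

Since `lf` is returned alongside the scheduler, callers can also evaluate lf(epoch) directly, e.g. to log or plot the planned LR curve before training.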