import torch


# ------------------------- WarmUp LR Scheduler -------------------------
## Warmup LR Scheduler
class LinearWarmUpScheduler(object):
    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.wp_iter = wp_iter
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr):
        # Scale each param group relative to its own initial lr, so groups with
        # different base lrs keep their ratio during warmup.
        # Note: 'initial_lr' is written into every param group by
        # torch.optim.lr_scheduler when the main scheduler is built.
        for param_group in optimizer.param_groups:
            init_lr = param_group['initial_lr']
            ratio = init_lr / self.base_lr
            param_group['lr'] = lr * ratio

    def __call__(self, iter, optimizer):
        # Linear warmup: ramp the lr from base_lr * warmup_factor up to base_lr
        # over the first wp_iter iterations.
        alpha = iter / self.wp_iter
        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        tmp_lr = self.base_lr * warmup_factor
        self.set_lr(optimizer, tmp_lr)

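# With the defaults above, the lr starts at base_lr * warmup_factor
# (0.01 * 0.00066667 ≈ 6.7e-6) at iteration 0 and ramps linearly toward
# base_lr (0.01) as the iteration count approaches wp_iter (500).
# A minimal usage sketch (the loop below is illustrative, not from the
# original training code):
#
#   warmup = LinearWarmUpScheduler(base_lr=0.01, wp_iter=500)
#   for cur_iter in range(500):
#       warmup(cur_iter, optimizer)   # overrides the optimizer's lr in place
#       ...                           # forward / backward / optimizer.step()
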
## Build WP LR Scheduler
def build_wp_lr_scheduler(cfg, base_lr=0.01):
    print('==============================')
    print('WarmUpScheduler: {}'.format(cfg['warmup']))
    print('--base_lr: {}'.format(base_lr))
    print('--warmup_iters: {}'.format(cfg['warmup_iters']))
    print('--warmup_factor: {}'.format(cfg['warmup_factor']))

    if cfg['warmup'] == 'linear':
        wp_lr_scheduler = LinearWarmUpScheduler(base_lr, cfg['warmup_iters'], cfg['warmup_factor'])
    else:
        # Fail loudly instead of returning an unbound variable.
        raise NotImplementedError('Unknown warmup scheduler: {}'.format(cfg['warmup']))

    return wp_lr_scheduler

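# The config is expected to provide at least the keys read above:
#   cfg['warmup']        -> warmup type; only 'linear' is handled here
#   cfg['warmup_iters']  -> number of warmup iterations
#   cfg['warmup_factor'] -> lr multiplier at iteration 0
# e.g. (values illustrative):
#   wp_lr_scheduler = build_wp_lr_scheduler(
#       {'warmup': 'linear', 'warmup_iters': 500, 'warmup_factor': 0.00066667},
#       base_lr=0.01)
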
# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg['lr_scheduler']))

    if cfg['lr_scheduler'] == 'step':
        assert 'lr_epoch' in cfg
        print('--lr_epoch: {}'.format(cfg['lr_epoch']))
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg['lr_epoch'])
    elif cfg['lr_scheduler'] == 'cosine':
        # Not implemented yet; fail loudly instead of returning an unbound variable.
        raise NotImplementedError('Cosine lr scheduler is not implemented yet.')
    else:
        raise NotImplementedError('Unknown lr scheduler: {}'.format(cfg['lr_scheduler']))

    if resume is not None:
        print('Load lr scheduler from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        # restore the scheduler state saved in the checkpoint
        checkpoint_state_dict = checkpoint.pop("lr_scheduler")
        lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
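

# ------------------------- Quick self-test -------------------------
# A minimal sketch of how the two schedulers compose: the warmup scheduler is
# called per iteration, the main scheduler is stepped per epoch. The config
# values, the dummy model, and the loop sizes below are illustrative
# assumptions, not taken from the original project.
if __name__ == '__main__':
    cfg = {
        'warmup': 'linear',
        'warmup_iters': 500,
        'warmup_factor': 0.00066667,
        'lr_scheduler': 'step',
        'lr_epoch': [90, 120],
    }
    base_lr = 0.01
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

    # Build the main scheduler first: it stores 'initial_lr' in every param
    # group, which LinearWarmUpScheduler.set_lr relies on.
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    wp_lr_scheduler = build_wp_lr_scheduler(cfg, base_lr)

    iters_per_epoch = 100
    for epoch in range(2):
        for it in range(iters_per_epoch):
            cur_iter = epoch * iters_per_epoch + it
            # Per-iteration warmup during the first 'warmup_iters' iterations.
            if cur_iter < cfg['warmup_iters']:
                wp_lr_scheduler(cur_iter, optimizer)
            # ... forward / backward / optimizer.step() would go here ...
        # Per-epoch step of the main scheduler.
        lr_scheduler.step()
        print('epoch {}: lr = {}'.format(epoch, optimizer.param_groups[0]['lr']))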