import numpy as np
import torch


# ------------------------- WarmUp LR Scheduler -------------------------
## Warmup LR Scheduler
class LinearWarmUpLrScheduler(object):
    def __init__(self, wp_iter=500, base_lr=0.01, warmup_bias_lr=0.0):
        self.wp_iter = wp_iter
        self.base_lr = base_lr
        self.warmup_bias_lr = warmup_bias_lr

    def set_lr(self, optimizer, cur_lr):
        # Force the same learning rate on every parameter group.
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

    def __call__(self, iter, optimizer):
        # Linear warmup over the first wp_iter iterations.
        xi = [0, self.wp_iter]
        for j, x in enumerate(optimizer.param_groups):
            # Group 0 (the bias group) ramps from warmup_bias_lr to its initial_lr;
            # all other groups rise from 0.0 to their initial_lr.
            x['lr'] = np.interp(
                iter, xi, [self.warmup_bias_lr if j == 0 else 0.0, x['initial_lr']])


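# A minimal usage sketch for the warmup scheduler, assuming a toy linear model
# and a plain SGD optimizer; the iteration counts are placeholders. np.interp
# reads each group's 'initial_lr', which torch LR schedulers record on the
# param groups when they are constructed, so a scheduler should be built first.
def _warmup_usage_example():
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Constructing any torch scheduler stores 'initial_lr' on optimizer.param_groups.
    torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    warmup = LinearWarmUpLrScheduler(wp_iter=500, base_lr=0.01, warmup_bias_lr=0.0)
    for it in range(1000):
        if it < warmup.wp_iter:
            warmup(it, optimizer)  # linearly ramp each group's lr toward initial_lr
        # ... forward / backward / optimizer.step() would run here ...

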
# ------------------------- LR Scheduler -------------------------
def build_lr_scheduler(cfg, optimizer, resume=None):
    print('==============================')
    print('LR Scheduler: {}'.format(cfg.lr_scheduler))

    if cfg.lr_scheduler == "step":
        lr_step = [cfg.max_epoch // 3, cfg.max_epoch // 3 * 2]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_step, gamma=0.1)
    elif cfg.lr_scheduler == "cosine":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.max_epoch - cfg.warmup_epoch - 1, eta_min=cfg.min_lr)
    else:
        raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))

    if resume is not None and resume.lower() != "none":
        checkpoint = torch.load(resume)
        if 'lr_scheduler' in checkpoint.keys():
            print('--Load lr scheduler from the checkpoint: ', resume)
            # checkpoint state dict
            checkpoint_state_dict = checkpoint.pop("lr_scheduler")
            lr_scheduler.load_state_dict(checkpoint_state_dict)

    return lr_scheduler
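

# A minimal usage sketch for build_lr_scheduler, assuming an ad-hoc config
# object; SimpleNamespace stands in for whatever config class the project
# actually uses, and only the fields read above (lr_scheduler, max_epoch,
# warmup_epoch, min_lr) are filled in.
def _build_lr_scheduler_example():
    from types import SimpleNamespace

    cfg = SimpleNamespace(lr_scheduler="cosine", max_epoch=100, warmup_epoch=3, min_lr=1e-6)
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    lr_scheduler = build_lr_scheduler(cfg, optimizer, resume=None)
    for epoch in range(cfg.max_epoch):
        # ... one training epoch (optimizer.step() calls) would run here ...
        lr_scheduler.step()  # advance the epoch-level schedule once per epoch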