```python
import torch
from torch import optim


def build_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base_lr: {}'.format(cfg.base_lr))
    print('--backbone_lr_ratio: {}'.format(cfg.backbone_lr_ratio))
    print('--momentum: {}'.format(cfg.momentum))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # Split the trainable parameters into two groups: backbone parameters use a
    # scaled learning rate (base_lr * backbone_lr_ratio), everything else uses base_lr.
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": cfg.base_lr * cfg.backbone_lr_ratio,
        },
    ]

    if cfg.optimizer == 'sgd':
        optimizer = optim.SGD(
            params=param_dicts,
            lr=cfg.base_lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay
        )
    elif cfg.optimizer == 'adamw':
        optimizer = optim.AdamW(
            params=param_dicts,
            lr=cfg.base_lr,
            weight_decay=cfg.weight_decay
        )
    else:
        raise NotImplementedError('Unknown optimizer: {}'.format(cfg.optimizer))

    start_epoch = 0
    if resume is not None:
        # Restore the optimizer state and the epoch counter from the checkpoint.
        print('Load optimizer from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1

    return optimizer, start_epoch
```
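For reference, here is a minimal sketch of how `build_optimizer` might be called. The config fields (`optimizer`, `base_lr`, `backbone_lr_ratio`, `momentum`, `weight_decay`) mirror the attributes read above, but the `SimpleNamespace` config and the `ToyDetector` model are illustrative assumptions, not part of the original training code.

```python
from types import SimpleNamespace

import torch.nn as nn

# Hypothetical config object; the real project presumably builds cfg elsewhere.
cfg = SimpleNamespace(
    optimizer='adamw',
    base_lr=1e-3,
    backbone_lr_ratio=0.1,   # backbone trains at 10% of the base learning rate
    momentum=0.9,            # only used by the SGD branch
    weight_decay=5e-2,
)

# Toy model with a 'backbone' submodule so the parameter-group split is exercised.
class ToyDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.head = nn.Linear(32, 4)

    def forward(self, x):
        return self.head(self.backbone(x))

model = ToyDetector()
optimizer, start_epoch = build_optimizer(cfg, model, resume=None)

# Two parameter groups: head params at base_lr, backbone params at base_lr * ratio.
for group in optimizer.param_groups:
    print(group['lr'], sum(p.numel() for p in group['params']))
```

Giving the backbone a smaller learning rate than the head is a common fine-tuning heuristic when the backbone starts from pretrained weights while the detection head is trained from scratch.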