optimizer.py

import torch
from torch import optim


def build_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base_lr: {}'.format(cfg.base_lr))
    print('--backbone_lr_ratio: {}'.format(cfg.backbone_lr_ratio))
    print('--momentum: {}'.format(cfg.momentum))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # Two parameter groups: non-backbone parameters use the base learning rate,
    # backbone parameters use base_lr scaled by backbone_lr_ratio.
    param_dicts = [
        {"params": [p for n, p in model.named_parameters()
                    if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" in n and p.requires_grad],
            "lr": cfg.base_lr * cfg.backbone_lr_ratio,
        },
    ]

    if cfg.optimizer == 'sgd':
        optimizer = optim.SGD(
            params=param_dicts,
            lr=cfg.base_lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay
        )
    elif cfg.optimizer == 'adamw':
        optimizer = optim.AdamW(
            params=param_dicts,
            lr=cfg.base_lr,
            weight_decay=cfg.weight_decay
        )
    else:
        raise ValueError('Unknown optimizer: {}'.format(cfg.optimizer))

    start_epoch = 0
    if resume is not None:
        print('Load optimizer from the checkpoint: ', resume)
        # Restore the optimizer state and resume from the epoch after the saved one.
        checkpoint = torch.load(resume)
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1

    return optimizer, start_epoch
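

# Minimal usage sketch (assumptions: a SimpleNamespace stands in for the real
# config object and a tiny nn.ModuleDict stands in for the actual model; the
# field names mirror the attributes accessed above, all values are illustrative).
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch import nn

    cfg = SimpleNamespace(
        optimizer='adamw',       # 'sgd' or 'adamw'
        base_lr=1e-4,
        backbone_lr_ratio=0.1,   # backbone lr = base_lr * backbone_lr_ratio
        momentum=0.9,            # only used by SGD
        weight_decay=1e-4,
    )

    # Toy model: the parameter-group split keys on "backbone" appearing in
    # parameter names, so name one submodule "backbone".
    model = nn.ModuleDict({
        'backbone': nn.Linear(8, 8),
        'head': nn.Linear(8, 2),
    })

    optimizer, start_epoch = build_optimizer(cfg, model)
    print(optimizer)
    print('start_epoch:', start_epoch)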