optimizer.py

import torch
from torch import optim


def build_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base_lr: {}'.format(cfg.base_lr))
    print('--backbone_lr_ratio: {}'.format(cfg.bk_lr_ratio))
    print('--momentum: {}'.format(cfg.momentum))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # Two parameter groups: non-backbone params use the base learning rate,
    # backbone params use base_lr scaled by bk_lr_ratio.
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": cfg.base_lr * cfg.bk_lr_ratio,
        },
    ]

    if cfg.optimizer == 'sgd':
        optimizer = optim.SGD(
            params=param_dicts,
            lr=cfg.base_lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay
        )
    elif cfg.optimizer == 'adamw':
        optimizer = optim.AdamW(
            params=param_dicts,
            lr=cfg.base_lr,
            weight_decay=cfg.weight_decay
        )
    else:
        raise NotImplementedError('Unknown optimizer: {}'.format(cfg.optimizer))

    start_epoch = 0
    cfg.best_map = -1.
    if resume is not None and resume.lower() != "none":
        print('Load optimizer from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        # Restore the optimizer state and resume from the epoch after the saved one.
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1
        if "mAP" in checkpoint:
            print('--Load best metric from the checkpoint: ', resume)
            best_map = checkpoint["mAP"]
            cfg.best_map = best_map / 100.0
        del checkpoint, checkpoint_state_dict

    return optimizer, start_epoch
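

# --- Usage sketch (not part of the original file) ---
# A minimal example of calling build_optimizer. The SimpleNamespace config and the
# dummy model below are illustrative assumptions; the config fields are inferred
# from the attributes printed above.
if __name__ == '__main__':
    from types import SimpleNamespace
    import torch.nn as nn

    cfg = SimpleNamespace(
        optimizer='adamw',     # 'sgd' or 'adamw'
        base_lr=1e-4,
        bk_lr_ratio=0.1,       # backbone group trains at 10% of base_lr
        momentum=0.9,          # only used when optimizer == 'sgd'
        weight_decay=5e-2,
    )
    model = nn.Sequential(nn.Linear(8, 8))   # stand-in for a model with a 'backbone'
    optimizer, start_epoch = build_optimizer(cfg, model, resume=None)
    print('start_epoch:', start_epoch)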