optimizer.py

import torch


def build_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--base lr: {}'.format(cfg['lr0']))
    print('--momentum: {}'.format(cfg['momentum']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    # ------------- Divide model's parameters -------------
    # group 0: biases, group 1: norm-layer weights, group 2: all other weights
    param_dicts = [], [], []
    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
    for n, p in model.named_parameters():
        if p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[0].append(p)   # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[1].append(p)   # no weight decay for all NormLayers' weight
                else:
                    param_dicts[2].append(p)   # weight decay for all Non-NormLayers' weight

    # Build optimizer (biases go in first, with weight decay disabled)
    if cfg['optimizer'] == 'sgd':
        optimizer = torch.optim.SGD(param_dicts[0], lr=cfg['lr0'], momentum=cfg['momentum'], weight_decay=0.0)
    elif cfg['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg['lr0'], weight_decay=0.0)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(cfg['optimizer']))

    # Add the remaining param groups: norm weights without decay, other weights with decay
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg['weight_decay']})

    # Optionally resume the optimizer state and starting epoch from a checkpoint
    start_epoch = 0
    if resume and resume != 'None':
        checkpoint = torch.load(resume)
        try:
            checkpoint_state_dict = checkpoint.pop("optimizer")
            print('Load optimizer from the checkpoint: ', resume)
            optimizer.load_state_dict(checkpoint_state_dict)
            start_epoch = checkpoint.pop("epoch") + 1
            del checkpoint, checkpoint_state_dict
        except KeyError:
            print("No optimizer in the given checkpoint.")

    return optimizer, start_epoch
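

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# The `cfg` keys mirror the lookups above; the tiny model and its values
# are assumptions used solely to show how parameters fall into the three
# groups (biases, norm-layer weights, other weights).
if __name__ == "__main__":
    import torch.nn as nn

    class _TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, 3, padding=1)   # weight -> decay group, bias -> no-decay group
            self.norm = nn.BatchNorm2d(16)               # "norm" weight/bias -> no-decay groups
            self.fc = nn.Linear(16, 10)

    cfg = {
        'optimizer': 'adamw',
        'lr0': 1e-3,
        'momentum': 0.9,
        'weight_decay': 0.05,
    }
    optimizer, start_epoch = build_optimizer(cfg, _TinyNet(), resume=None)
    print(optimizer)
    print('start_epoch:', start_epoch)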