# optimizer.py

import torch
from torch import optim


def build_optimizer(optimizer_cfg, model, param_dicts=None, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(optimizer_cfg['optimizer']))
    print('--base_lr: {}'.format(optimizer_cfg['base_lr']))
    print('--backbone_lr_ratio: {}'.format(optimizer_cfg['backbone_lr_ratio']))
    print('--momentum: {}'.format(optimizer_cfg['momentum']))
    print('--weight_decay: {}'.format(optimizer_cfg['weight_decay']))

    # Default split: backbone parameters get a scaled learning rate,
    # everything else uses the base learning rate.
    if param_dicts is None:
        param_dicts = [
            {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": optimizer_cfg['base_lr'] * optimizer_cfg['backbone_lr_ratio'],
            },
        ]

    if optimizer_cfg['optimizer'] == 'sgd':
        optimizer = optim.SGD(
            params=param_dicts,
            lr=optimizer_cfg['base_lr'],
            momentum=optimizer_cfg['momentum'],
            weight_decay=optimizer_cfg['weight_decay']
        )
    elif optimizer_cfg['optimizer'] == 'adamw':
        optimizer = optim.AdamW(
            params=param_dicts,
            lr=optimizer_cfg['base_lr'],
            weight_decay=optimizer_cfg['weight_decay']
        )
    else:
        raise NotImplementedError('Unknown optimizer: {}'.format(optimizer_cfg['optimizer']))

    start_epoch = 0
    if resume is not None:
        print('keep training: ', resume)
        checkpoint = torch.load(resume)
        # checkpoint state dict: restore optimizer state and the epoch to resume from
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1

    return optimizer, start_epoch
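

# Example call for build_optimizer (a minimal sketch; the config keys mirror the
# ones read above, but the values are illustrative, not the project's defaults):
#
#   cfg = {'optimizer': 'sgd', 'base_lr': 0.01, 'backbone_lr_ratio': 0.1,
#          'momentum': 0.9, 'weight_decay': 5e-4}
#   optimizer, start_epoch = build_optimizer(cfg, model)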


def build_detr_optimizer(optimizer_cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(optimizer_cfg['optimizer']))
    print('--base_lr: {}'.format(optimizer_cfg['base_lr']))
    print('--backbone_lr_ratio: {}'.format(optimizer_cfg['backbone_lr_ratio']))
    print('--weight_decay: {}'.format(optimizer_cfg['weight_decay']))

    # ------------- Divide model's parameters -------------
    # Seven buckets: [0-3] non-backbone, [4-6] backbone. Biases and NormLayer
    # weights get no weight decay; the remaining weights use the configured decay.
    param_dicts = [], [], [], [], [], [], []
    # Matches parameters whose parent module is named "norm", "norm0", "norm1", ...
    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
    for n, p in model.named_parameters():
        # Non-Backbone's learnable parameters
        if "backbone" not in n and p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[0].append(p)  # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
                elif "cpb_mlp1" in n.split(".") or "cpb_mlp2" in n.split("."):
                    param_dicts[2].append(p)  # no weight decay for plain-detr cpb_mlp weight
                else:
                    param_dicts[3].append(p)  # weight decay for all Non-NormLayers' weight
        # Backbone's learnable parameters
        elif "backbone" in n and p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[4].append(p)  # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[5].append(p)  # no weight decay for all NormLayers' weight
                else:
                    param_dicts[6].append(p)  # weight decay for all Non-NormLayers' weight

    # Non-Backbone's learnable parameters
    optimizer = torch.optim.AdamW(param_dicts[0], lr=optimizer_cfg['base_lr'], weight_decay=0.0)
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[3], "weight_decay": optimizer_cfg['weight_decay']})
    # Backbone's learnable parameters (scaled learning rate)
    backbone_lr = optimizer_cfg['base_lr'] * optimizer_cfg['backbone_lr_ratio']
    optimizer.add_param_group({"params": param_dicts[4], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[5], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[6], "lr": backbone_lr, "weight_decay": optimizer_cfg['weight_decay']})

    start_epoch = 0
    if resume is not None:
        print('keep training: ', resume)
        checkpoint = torch.load(resume)
        # checkpoint state dict: restore optimizer state and the epoch to resume from
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1

    return optimizer, start_epoch
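

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training code): a toy
    # model with a submodule literally named "backbone", so both the base-lr and the
    # scaled-lr parameter groups get populated. The config values are placeholders.
    import torch.nn as nn

    class _ToyDetector(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
            self.head = nn.Linear(8, 4)

        def forward(self, x):
            feat = self.backbone(x).mean(dim=(2, 3))
            return self.head(feat)

    cfg = {
        'optimizer': 'adamw',
        'base_lr': 1e-4,
        'backbone_lr_ratio': 0.1,
        'momentum': 0.9,
        'weight_decay': 1e-4,
    }
    model = _ToyDetector()

    optimizer, start_epoch = build_optimizer(cfg, model)
    print('build_optimizer param groups:', len(optimizer.param_groups))

    optimizer, start_epoch = build_detr_optimizer(cfg, model)
    print('build_detr_optimizer param groups:', len(optimizer.param_groups))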