optimizer.py

import torch


def build_yolo_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base lr: {}'.format(cfg.base_lr))
    print('--min lr: {}'.format(cfg.min_lr))
    print('--momentum: {}'.format(cfg.momentum))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # ------------- Divide model's parameters -------------
    param_dicts = [], [], []
    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
    for n, p in model.named_parameters():
        if p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[0].append(p)  # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
                else:
                    param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight

    # Build optimizer
    if cfg.optimizer == 'sgd':
        optimizer = torch.optim.SGD(param_dicts[0], lr=cfg.base_lr, momentum=cfg.momentum, weight_decay=0.0)
    elif cfg.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg.base_lr, weight_decay=0.0)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(cfg.optimizer))

    # Add param groups
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg.weight_decay})

    start_epoch = 0
    cfg.best_map = -1.
    if resume and resume != 'None':
        checkpoint = torch.load(resume)
        # checkpoint state dict
        try:
            checkpoint_state_dict = checkpoint.pop("optimizer")
            print('--Load optimizer from the checkpoint: ', resume)
            optimizer.load_state_dict(checkpoint_state_dict)
            start_epoch = checkpoint.pop("epoch") + 1
            if "mAP" in checkpoint:
                print('--Load best metric from the checkpoint: ', resume)
                best_map = checkpoint["mAP"]
                cfg.best_map = best_map
            del checkpoint, checkpoint_state_dict
        except KeyError:
            print("No optimizer in the given checkpoint.")

    return optimizer, start_epoch

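
# A minimal usage sketch for build_yolo_optimizer (hypothetical values; the real
# `cfg` object comes from the project's config system, which is not part of this
# file). TinyBlock is only an illustration: its attribute named "norm" is what the
# norm_names check above matches on.
#
#   from types import SimpleNamespace
#   import torch.nn as nn
#
#   class TinyBlock(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = nn.Conv2d(3, 16, 3)
#           self.norm = nn.BatchNorm2d(16)
#
#   cfg = SimpleNamespace(optimizer='sgd', base_lr=0.01, min_lr=0.0001,
#                         momentum=0.937, weight_decay=5e-4)
#   optimizer, start_epoch = build_yolo_optimizer(cfg, TinyBlock(), resume=None)
#   # -> SGD with three param groups: biases (no decay), norm weights (no decay),
#   #    remaining weights (decayed by cfg.weight_decay)
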
def build_rtdetr_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base lr: {}'.format(cfg.base_lr))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # ------------- Divide model's parameters -------------
    param_dicts = [], [], [], [], [], []
    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
    for n, p in model.named_parameters():
        # Non-Backbone's learnable parameters
        if "backbone" not in n and p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[0].append(p)  # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
                else:
                    param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight
        # Backbone's learnable parameters
        elif "backbone" in n and p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[3].append(p)  # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[4].append(p)  # no weight decay for all NormLayers' weight
                else:
                    param_dicts[5].append(p)  # weight decay for all Non-NormLayers' weight

    # Non-Backbone's learnable parameters
    optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg.base_lr, weight_decay=0.0)
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg.weight_decay})

    # Backbone's learnable parameters
    backbone_lr = cfg.base_lr * cfg.backbone_lr_ratio
    optimizer.add_param_group({"params": param_dicts[3], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[4], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[5], "lr": backbone_lr, "weight_decay": cfg.weight_decay})

    start_epoch = 0
    if resume and resume != 'None':
        print('--Load optimizer from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        # checkpoint state dict
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1

    return optimizer, start_epoch
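
# A minimal usage sketch for build_rtdetr_optimizer (hypothetical values; the real
# cfg and detector come from the project's config and model builders, not from
# this file). TinyDetector only illustrates the "backbone" name check: any
# parameter whose name contains "backbone" gets lr scaled by backbone_lr_ratio.
#
#   from types import SimpleNamespace
#   import torch.nn as nn
#
#   class TinyDetector(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.backbone = nn.Conv2d(3, 16, 3)   # params named "backbone.*"
#           self.head = nn.Linear(16, 4)          # non-backbone params
#
#   cfg = SimpleNamespace(optimizer='adamw', base_lr=1e-4,
#                         weight_decay=1e-4, backbone_lr_ratio=0.1)
#   optimizer, start_epoch = build_rtdetr_optimizer(cfg, TinyDetector(), resume=None)
#   # -> AdamW with six param groups; the three backbone groups use
#   #    lr = cfg.base_lr * cfg.backbone_lr_ratio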