optimizer.py

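"""Optimizer builders for detection models.

Each builder splits the model's learnable parameters into groups
(backbone vs. non-backbone, or bias vs. NormLayer weight vs. other
weights), constructs an SGD or AdamW optimizer over those groups, and
optionally restores the optimizer state and start epoch from a resume
checkpoint.
"""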
import torch


def build_simple_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base lr: {}'.format(cfg.base_lr))
    print('--min lr: {}'.format(cfg.min_lr))
    print('--momentum: {}'.format(cfg.momentum))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # ------------- Divide model's parameters -------------
    # Non-backbone parameters use the base lr; backbone parameters use a scaled lr.
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": cfg.base_lr * cfg.bk_lr_ratio,
        },
    ]

    # Build optimizer
    if cfg.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            params=param_dicts,
            lr=cfg.base_lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay
        )
    elif cfg.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(
            params=param_dicts,
            lr=cfg.base_lr,
            weight_decay=cfg.weight_decay
        )
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(cfg.optimizer))

    # Optionally resume optimizer state, start epoch and best metric from a checkpoint
    start_epoch = 0
    cfg.best_map = -1.
    if resume and resume != 'None':
        checkpoint = torch.load(resume)
        try:
            checkpoint_state_dict = checkpoint.pop("optimizer")
            print('--Load optimizer from the checkpoint: ', resume)
            optimizer.load_state_dict(checkpoint_state_dict)
            start_epoch = checkpoint.pop("epoch") + 1
            if "mAP" in checkpoint:
                print('--Load best metric from the checkpoint: ', resume)
                cfg.best_map = checkpoint["mAP"]
            del checkpoint, checkpoint_state_dict
        except KeyError:
            print("No optimizer in the given checkpoint.")

    return optimizer, start_epoch


def build_yolo_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base lr: {}'.format(cfg.base_lr))
    print('--min lr: {}'.format(cfg.min_lr))
    print('--momentum: {}'.format(cfg.momentum))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # ------------- Divide model's parameters -------------
    # Group 0: biases, group 1: NormLayer weights, group 2: all other weights.
    param_dicts = [], [], []
    norm_names = {"norm"} | {"norm{}".format(i) for i in range(10000)}
    for n, p in model.named_parameters():
        if p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[0].append(p)   # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[1].append(p)   # no weight decay for all NormLayers' weight
                else:
                    param_dicts[2].append(p)   # weight decay for all Non-NormLayers' weight

    # Build optimizer
    if cfg.optimizer == 'sgd':
        optimizer = torch.optim.SGD(param_dicts[0], lr=cfg.base_lr, momentum=cfg.momentum, weight_decay=0.0)
    elif cfg.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg.base_lr, weight_decay=0.0)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(cfg.optimizer))

    # Add param groups
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg.weight_decay})

    # Optionally resume optimizer state, start epoch and best metric from a checkpoint
    start_epoch = 0
    cfg.best_map = -1.
    if resume and resume != 'None':
        checkpoint = torch.load(resume)
        try:
            checkpoint_state_dict = checkpoint.pop("optimizer")
            print('--Load optimizer from the checkpoint: ', resume)
            optimizer.load_state_dict(checkpoint_state_dict)
            start_epoch = checkpoint.pop("epoch") + 1
            if "mAP" in checkpoint:
                print('--Load best metric from the checkpoint: ', resume)
                cfg.best_map = checkpoint["mAP"]
            del checkpoint, checkpoint_state_dict
        except KeyError:
            print("No optimizer in the given checkpoint.")

    return optimizer, start_epoch


def build_rtdetr_optimizer(cfg, model, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg.optimizer))
    print('--base lr: {}'.format(cfg.base_lr))
    print('--weight_decay: {}'.format(cfg.weight_decay))

    # ------------- Divide model's parameters -------------
    param_dicts = [], [], [], [], [], []
    norm_names = {"norm"} | {"norm{}".format(i) for i in range(10000)}
    for n, p in model.named_parameters():
        # Non-Backbone's learnable parameters
        if "backbone" not in n and p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[0].append(p)   # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[1].append(p)   # no weight decay for all NormLayers' weight
                else:
                    param_dicts[2].append(p)   # weight decay for all Non-NormLayers' weight
        # Backbone's learnable parameters
        elif "backbone" in n and p.requires_grad:
            if "bias" == n.split(".")[-1]:
                param_dicts[3].append(p)   # no weight decay for all layers' bias
            else:
                if n.split(".")[-2] in norm_names:
                    param_dicts[4].append(p)   # no weight decay for all NormLayers' weight
                else:
                    param_dicts[5].append(p)   # weight decay for all Non-NormLayers' weight

    # Non-Backbone's learnable parameters
    optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg.base_lr, weight_decay=0.0)
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg.weight_decay})

    # Backbone's learnable parameters
    backbone_lr = cfg.base_lr * cfg.backbone_lr_ratio
    optimizer.add_param_group({"params": param_dicts[3], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[4], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[5], "lr": backbone_lr, "weight_decay": cfg.weight_decay})

    start_epoch = 0
    if resume and resume != 'None':
        print('--Load optimizer from the checkpoint: ', resume)
        checkpoint = torch.load(resume)
        # checkpoint state dict
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1

    return optimizer, start_epoch
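

# ------------- Usage sketch -------------
# A minimal, illustrative example of how one of the builders above can be
# called. The SimpleNamespace config and the toy model below are assumptions
# for demonstration only; any cfg object exposing the same fields
# (optimizer, base_lr, min_lr, momentum, weight_decay, bk_lr_ratio) and any
# torch model whose backbone parameters contain "backbone" in their names
# work the same way.
if __name__ == '__main__':
    from types import SimpleNamespace
    import torch.nn as nn

    cfg = SimpleNamespace(
        optimizer='adamw',
        base_lr=1e-3,
        min_lr=1e-5,
        momentum=0.9,
        weight_decay=5e-2,
        bk_lr_ratio=0.1,
    )
    # Toy model with a "backbone" submodule so the parameter split applies.
    model = nn.Sequential()
    model.add_module('backbone', nn.Linear(8, 8))
    model.add_module('head', nn.Linear(8, 4))

    optimizer, start_epoch = build_simple_optimizer(cfg, model)
    print(optimizer)
    print('start epoch: {}'.format(start_epoch))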