optimizer.py

import torch
import torch.nn as nn


def build_yolo_optimizer(cfg, model, resume=None):
    """Build the YOLO optimizer with three parameter groups:
    biases (no decay), conv/linear weights (decay), and normalization weights (no decay)."""
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--base lr: {}'.format(cfg['lr0']))
    print('--momentum: {}'.format(cfg['momentum']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    g = [], [], []  # optimizer parameter groups: g[0] weights (decay), g[1] norm weights (no decay), g[2] biases (no decay)
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
            g[2].append(v.bias)
        if isinstance(v, bn):  # normalization weight (no decay)
            g[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g[0].append(v.weight)

    if cfg['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(g[2], lr=cfg['lr0'], betas=(cfg['momentum'], 0.999))  # beta1 set to momentum
    elif cfg['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(g[2], lr=cfg['lr0'], weight_decay=0.0)
    elif cfg['optimizer'] == 'sgd':
        optimizer = torch.optim.SGD(g[2], lr=cfg['lr0'], momentum=cfg['momentum'], nesterov=True)
    else:
        raise NotImplementedError('Optimizer {} not implemented.'.format(cfg['optimizer']))

    optimizer.add_param_group({'params': g[0], 'weight_decay': cfg['weight_decay']})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)

    start_epoch = 0
    if resume and resume != 'None':
        print('keep training: ', resume)
        checkpoint = torch.load(resume, map_location='cpu')
        # restore the optimizer state and resume from the next epoch
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1
        del checkpoint, checkpoint_state_dict

    return optimizer, start_epoch
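
# Usage sketch (illustrative, not a repo default): the cfg below contains only the
# keys that build_yolo_optimizer actually reads, and the tiny Sequential model is a
# stand-in for the real YOLO network.
#
#   cfg = {'optimizer': 'sgd', 'lr0': 0.01, 'momentum': 0.937, 'weight_decay': 5e-4}
#   toy_model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
#   optimizer, start_epoch = build_yolo_optimizer(cfg, toy_model)
#   # param group order: [biases (no decay), weights (decay), norm weights (no decay)]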


def build_rtdetr_optimizer(cfg, model, resume=None):
    """Build the RT-DETR AdamW optimizer with six parameter groups:
    non-backbone and backbone parameters, each split into biases (no decay),
    normalization weights (no decay), and remaining weights (decay).
    The backbone groups use a learning rate scaled by cfg['backbone_lr_ratio']."""
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--base lr: {}'.format(cfg['lr0']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    # ------------- Divide model's parameters -------------
    # param_dicts[0:3] -> non-backbone: bias / norm weight / other weight
    # param_dicts[3:6] -> backbone:     bias / norm weight / other weight
    param_dicts = [], [], [], [], [], []
    norm_names = {"norm"} | {"norm{}".format(i) for i in range(10000)}
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Non-backbone's learnable parameters
        if "backbone" not in n:
            if n.split(".")[-1] == "bias":
                param_dicts[0].append(p)  # no weight decay for all layers' bias
            elif n.split(".")[-2] in norm_names:
                param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
            else:
                param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight
        # Backbone's learnable parameters
        else:
            if n.split(".")[-1] == "bias":
                param_dicts[3].append(p)  # no weight decay for all layers' bias
            elif n.split(".")[-2] in norm_names:
                param_dicts[4].append(p)  # no weight decay for all NormLayers' weight
            else:
                param_dicts[5].append(p)  # weight decay for all Non-NormLayers' weight

    # Non-backbone's learnable parameters
    optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg['lr0'], weight_decay=0.0)
    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg['weight_decay']})
    # Backbone's learnable parameters (scaled learning rate)
    backbone_lr = cfg['lr0'] * cfg['backbone_lr_ratio']
    optimizer.add_param_group({"params": param_dicts[3], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[4], "lr": backbone_lr, "weight_decay": 0.0})
    optimizer.add_param_group({"params": param_dicts[5], "lr": backbone_lr, "weight_decay": cfg['weight_decay']})

    start_epoch = 0
    if resume and resume != 'None':
        print('keep training: ', resume)
        checkpoint = torch.load(resume, map_location='cpu')
        # restore the optimizer state and resume from the next epoch
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch") + 1
        del checkpoint, checkpoint_state_dict

    return optimizer, start_epoch
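

if __name__ == '__main__':
    # Minimal smoke test for build_rtdetr_optimizer. This is a sketch under
    # assumptions: the cfg keys and the ToyDETR module below are illustrative
    # stand-ins, not the repo's real config or model.
    class ToyDETR(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
            self.norm = nn.LayerNorm(16)  # attribute named 'norm' -> matched by the norm-name split
            self.head = nn.Linear(16, 4)

    cfg = {'optimizer': 'adamw', 'lr0': 1e-4, 'weight_decay': 1e-4, 'backbone_lr_ratio': 0.1}
    optimizer, start_epoch = build_rtdetr_optimizer(cfg, ToyDETR())
    # Expect 6 param groups: non-backbone {bias, norm weight, weight} at lr0,
    # backbone {bias, norm weight, weight} at lr0 * backbone_lr_ratio.
    print('param groups:', len(optimizer.param_groups), '| start epoch:', start_epoch)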