optimizer.py

import torch
import torch.nn as nn
from torch import optim


def build_optimizer(cfg, model, base_lr=0.0, resume=None):
    """Build an optimizer for `model` from the config dict `cfg`.

    Expects cfg['optimizer'] in {'sgd', 'yolov5_sgd', 'adam', 'adamw'},
    plus cfg['momentum'] and cfg['weight_decay']. If `resume` points to a
    checkpoint file, the optimizer state and start epoch are restored.
    Returns (optimizer, start_epoch).
    """
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--momentum: {}'.format(cfg['momentum']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    if cfg['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=base_lr,
                              momentum=cfg['momentum'],
                              weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'yolov5_sgd':
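        # YOLOv5-style parameter grouping: BatchNorm weights and all biases
        # are kept out of weight decay, which only regularizes the remaining
        # (conv/linear) weights. pg0 = BN weights (no decay), pg1 = other
        # weights (decay applied), pg2 = biases (no decay).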
        pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                pg2.append(v.bias)  # biases
            if isinstance(v, nn.BatchNorm2d) or "bn" in k:
                pg0.append(v.weight)  # no decay
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                pg1.append(v.weight)  # apply decay
        optimizer = optim.SGD(
            pg0, lr=base_lr, momentum=cfg['momentum'], nesterov=True
        )
        optimizer.add_param_group(
            {"params": pg1, "weight_decay": cfg['weight_decay']}
        )  # add pg1 with weight_decay
        optimizer.add_param_group({"params": pg2})
    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=base_lr,
                               weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'adamw':
        optimizer = optim.AdamW(model.parameters(),
                                lr=base_lr,
                                weight_decay=cfg['weight_decay'])
    else:
        raise ValueError('Unknown optimizer: {}'.format(cfg['optimizer']))

    start_epoch = 0
    if resume is not None:
        print('keep training: ', resume)
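        # Only the optimizer state and start epoch are restored here; the
        # model weights are expected to be loaded from the same checkpoint
        # elsewhere in the training script.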
        checkpoint = torch.load(resume)
        # checkpoint state dict
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch")

    return optimizer, start_epoch
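

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how build_optimizer might be called. The cfg values and
# the toy model below are assumptions for illustration, not part of this
# file's required API.
if __name__ == '__main__':
    cfg = {
        'optimizer': 'yolov5_sgd',  # one of: 'sgd', 'yolov5_sgd', 'adam', 'adamw'
        'momentum': 0.937,          # assumed value for illustration
        'weight_decay': 5e-4,       # assumed value for illustration
    }
    # Small model with a conv, a BatchNorm, and biases, so that every
    # parameter group in the 'yolov5_sgd' branch gets populated.
    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )
    optimizer, start_epoch = build_optimizer(cfg, model, base_lr=0.01)
    print(optimizer)
    print('start_epoch:', start_epoch)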