import torch
import torch.nn as nn
from torch import optim


def build_optimizer(cfg, model, base_lr=0.0, resume=None):
    """Build the optimizer from the config dict and, optionally, resume its state."""
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--momentum: {}'.format(cfg['momentum']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    if cfg['optimizer'] == 'sgd':
        # plain SGD: one parameter group, weight decay applied to everything
        optimizer = optim.SGD(model.parameters(),
                              lr=base_lr,
                              momentum=cfg['momentum'],
                              weight_decay=cfg['weight_decay'])

    elif cfg['optimizer'] == 'yolov5_sgd':
        # YOLOv5-style SGD: split parameters into three groups so that
        # weight decay is applied only to conv/linear weights.
        pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                pg2.append(v.bias)    # biases (no decay)
            if isinstance(v, nn.BatchNorm2d) or "bn" in k:
                pg0.append(v.weight)  # BN weights (no decay)
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                pg1.append(v.weight)  # conv/linear weights (apply decay)
        optimizer = optim.SGD(
            pg0, lr=base_lr, momentum=cfg['momentum'], nesterov=True
        )
        optimizer.add_param_group(
            {"params": pg1, "weight_decay": cfg['weight_decay']}
        )  # add pg1 with weight_decay
        optimizer.add_param_group({"params": pg2})  # add pg2 (biases) without decay

    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=base_lr,
                               weight_decay=cfg['weight_decay'])

    elif cfg['optimizer'] == 'adamw':
        optimizer = optim.AdamW(model.parameters(),
                                lr=base_lr,
                                weight_decay=cfg['weight_decay'])

    else:
        raise NotImplementedError('Unknown optimizer: {}'.format(cfg['optimizer']))

    start_epoch = 0
    if resume is not None:
        print('keep training: ', resume)
        checkpoint = torch.load(resume)
        # restore the optimizer state and the epoch to resume from
        checkpoint_state_dict = checkpoint.pop("optimizer")
        optimizer.load_state_dict(checkpoint_state_dict)
        start_epoch = checkpoint.pop("epoch")

    return optimizer, start_epoch
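For context, a minimal usage sketch is shown below. The cfg values, the toy model, and the checkpoint path are illustrative assumptions, not values from the original code; the only requirement the function itself places on a resume checkpoint is that it contains the "optimizer" and "epoch" entries it pops.

```python
import torch
import torch.nn as nn

# Hypothetical config and model, for illustration only.
cfg = {
    'optimizer': 'yolov5_sgd',  # or 'sgd' / 'adam' / 'adamw'
    'momentum': 0.9,
    'weight_decay': 5e-4,
}
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)

optimizer, start_epoch = build_optimizer(cfg, model, base_lr=0.01)
# 'yolov5_sgd' produces three parameter groups:
# BN weights (no decay), conv/linear weights (with decay), biases (no decay).
print(len(optimizer.param_groups))  # -> 3

# A checkpoint usable via `resume` only needs these two keys
# (a real training loop would also store the model weights):
torch.save({'optimizer': optimizer.state_dict(), 'epoch': 10}, 'ckpt.pth')
optimizer, start_epoch = build_optimizer(cfg, model, base_lr=0.01,
                                         resume='ckpt.pth')
print(start_epoch)  # -> 10
```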