|
|
@@ -373,14 +373,6 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
|
|
|
|
|
|
return model
|
|
|
|
|
|
-def get_total_grad_norm(parameters, norm_type=2):
|
|
|
- parameters = list(filter(lambda p: p.grad is not None, parameters))
|
|
|
- norm_type = float(norm_type)
|
|
|
- device = parameters[0].grad.device
|
|
|
- total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
|
|
|
- norm_type)
|
|
|
- return total_norm
|
|
|
-
|
|
|
## Model EMA
|
|
|
class ModelEMA(object):
|
|
|
def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
|