# =====================================================================
# Copyright 2021 RangiLyu. All rights reserved.
# =====================================================================
# Modified from: https://github.com/facebookresearch/d2go
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Licensed under the Apache License, Version 2.0 (the "License")

import math
from copy import deepcopy

import torch
import torch.nn as nn


# Modified from the YOLOv5 project
class ModelEMA(object):
    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = 0  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)
        if resume is not None and resume.lower() != "none":
            self.load_resume(resume)
        print("Initialize ModelEMA's updates: {}".format(self.updates))

    def load_resume(self, resume):
        checkpoint = torch.load(resume)
        if "model_ema" in checkpoint.keys():
            print("--Load ModelEMA state dict from the checkpoint: ", resume)
            model_ema_state_dict = checkpoint["model_ema"]
            self.ema.load_state_dict(model_ema_state_dict)
        if "ema_updates" in checkpoint.keys():
            print("--Load ModelEMA updates from the checkpoint: ", resume)
            # checkpoint state dict
            self.updates = checkpoint.pop("ema_updates")
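
    # Note: the checkpoint layout assumed by load_resume is a dict that may
    # carry "model_ema" (an EMA state dict) and "ema_updates" (the update
    # counter) alongside the regular training state, e.g. saved roughly as
    #   torch.save({"model_ema": ema.ema.state_dict(), "ema_updates": ema.updates, ...}, path)
    # (illustrative sketch only; the exact saving code lives in the trainer).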

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (
            nn.parallel.DataParallel,
            nn.parallel.DistributedDataParallel,
        )

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, options to only include [...] and to exclude [...]
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith("_") or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = self.de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # true for FP16 and FP32
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)
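

# Minimal usage sketch (not part of the original module): the tiny model,
# optimizer, and random data below are hypothetical stand-ins, used only to
# illustrate the intended call order of ModelEMA in a training loop.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ema = ModelEMA(model)

    for step in range(100):
        x = torch.randn(32, 8)
        target = torch.randint(0, 2, (32,))
        loss = nn.functional.cross_entropy(model(x), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # EMA tracks the online model after every optimizer step

    # Evaluate or export with the averaged weights instead of the raw model.
    averaged_model = ema.ema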