# ema.py
  1. from copy import deepcopy
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. # ---------------------------- Model tools ----------------------------
  6. def is_parallel(model):
  7. # Returns True if model is of type DP or DDP
  8. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  9. ## Model EMA
  10. class ModelEMA(object):
  11. def __init__(self, model, ema_decay=0.9999, ema_tau=2000, updates=0):
  12. # Create EMA
  13. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  14. self.updates = updates # number of EMA updates
  15. self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau)) # decay exponential ramp (to help early epochs)
  16. for p in self.ema.parameters():
  17. p.requires_grad_(False)
  18. def is_parallel(self, model):
  19. # Returns True if model is of type DP or DDP
  20. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  21. def de_parallel(self, model):
  22. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  23. return model.module if self.is_parallel(model) else model
  24. def copy_attr(self, a, b, include=(), exclude=()):
  25. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  26. for k, v in b.__dict__.items():
  27. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  28. continue
  29. else:
  30. setattr(a, k, v)
  31. def update(self, model):
  32. # Update EMA parameters
  33. self.updates += 1
  34. d = self.decay(self.updates)
  35. msd = self.de_parallel(model).state_dict() # model state_dict
  36. for k, v in self.ema.state_dict().items():
  37. if v.dtype.is_floating_point: # true for FP16 and FP32
  38. v *= d
  39. v += (1 - d) * msd[k].detach()
  40. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  41. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  42. # Update EMA attributes
  43. self.copy_attr(self.ema, model, include, exclude)