ema.py

# =====================================================================
# Copyright 2021 RangiLyu. All rights reserved.
# =====================================================================
# Modified from: https://github.com/facebookresearch/d2go
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Licensed under the Apache License, Version 2.0 (the "License")
import math
from copy import deepcopy

import torch
import torch.nn as nn


# Modified from the YOLOv5 project
class ModelEMA(object):
    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = 0  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)
        if resume is not None and resume.lower() != "none":
            self.load_resume(resume)
        print("Initialize ModelEMA's updates: {}".format(self.updates))

    def load_resume(self, resume):
        checkpoint = torch.load(resume)
        if "model_ema" in checkpoint.keys():
            print("--Load ModelEMA state dict from the checkpoint: ", resume)
            model_ema_state_dict = checkpoint["model_ema"]
            self.ema.load_state_dict(model_ema_state_dict)
        if "ema_updates" in checkpoint.keys():
            print("--Load ModelEMA updates from the checkpoint: ", resume)
            # checkpoint state dict
            self.updates = checkpoint.pop("ema_updates")

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (
            nn.parallel.DataParallel,
            nn.parallel.DistributedDataParallel,
        )

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, with options to only include [...] and to exclude [...]
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith("_") or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)
        msd = self.de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)