@@ -118,27 +118,50 @@ def is_parallel(model):
     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


+def de_parallel(model):
+    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    return model.module if is_parallel(model) else model
+
+
+def copy_attr(a, b, include=(), exclude=()):
+    # Copy attributes from b to a, options to only include [...] and to exclude [...]
+    for k, v in b.__dict__.items():
+        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+            continue
+        else:
+            setattr(a, k, v)
+
+
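A minimal sketch of how the two new helpers behave; the `nn.Linear` module and the `Src` class below are illustrative placeholders, not part of the patch:

import torch.nn as nn

m = nn.Linear(4, 2)                  # placeholder module
dp = nn.parallel.DataParallel(m)
assert de_parallel(dp) is m          # DP/DDP wrappers are unwrapped via .module
assert de_parallel(m) is m           # plain modules pass through unchanged

class Src: pass                      # hypothetical attribute carrier
src, dst = Src(), Src()
src.stride, src._private = 32, 1
copy_attr(dst, src)
assert dst.stride == 32 and not hasattr(dst, '_private')  # '_'-prefixed attrs are skipped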
 # Model EMA
 class ModelEMA(object):
-    def __init__(self, model, decay=0.9999, updates=0):
-        # create EMA
-        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
-        self.updates = updates
-        self.decay = lambda x: decay * (1 - math.exp(-x / 2000.))
+    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
             p.requires_grad_(False)
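The ramp means the effective decay is tiny at first and approaches `decay` asymptotically, so the EMA roughly copies the raw weights during warm-up instead of being dominated by the random init. A quick sketch of the schedule under the new defaults (values approximate):

import math

decay, tau = 0.9999, 2000
d = lambda x: decay * (1 - math.exp(-x / tau))
print([round(d(x), 4) for x in (1, 100, 2000, 10000)])
# [0.0005, 0.0488, 0.6321, 0.9932] -- near 0 early, ~0.9999 in the limit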

     def update(self, model):
         # Update EMA parameters
-        with torch.no_grad():
-            self.updates += 1
-            d = self.decay(self.updates)
-
-            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
-            for k, v in self.ema.state_dict().items():
-                if v.dtype.is_floating_point:
-                    v *= d
-                    v += (1. - d) * msd[k].detach()
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        copy_attr(self.ema, model, include, exclude)
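For context, a hedged sketch of how this class is typically driven from a training loop; `model`, `optimizer`, `loader`, and `compute_loss` are placeholders, not defined in this patch:

ema = ModelEMA(model)                # FP32 EMA copy, gradients disabled
for imgs, targets in loader:
    loss = compute_loss(model(imgs), targets)  # hypothetical loss helper
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                # v = d*v + (1-d)*w for every float entry
ema.update_attr(model)               # mirror non-tensor attributes before eval/save
# evaluate or checkpoint ema.ema, the smoothed copy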


 class CollateFunc(object):