```python
from copy import deepcopy
import math

import torch
import torch.nn as nn


# ---------------------------- Model tools ----------------------------
def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


## Model EMA
class ModelEMA(object):
    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, options to only include [...] and to exclude [...]
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)
        msd = self.de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
                # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)
```
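Below is a minimal usage sketch of how a `ModelEMA` like the one above is typically wired into a training loop: call `update()` once after every optimizer step so the shadow copy tracks the training weights, then evaluate or checkpoint against the smoothed `ema.ema` module. The `model`, `optimizer`, `loss_fn`, and `train_loader` objects here are hypothetical placeholders, not part of the snippet above.

```python
import torch
import torch.nn as nn

# Hypothetical training setup (assumptions, for illustration only)
model = nn.Linear(10, 2)                                   # placeholder network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
train_loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,)))
                for _ in range(8)]                         # dummy data

ema = ModelEMA(model)                                      # shadow copy, parameters frozen

for epoch in range(2):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
        ema.update(model)                                  # blend current weights into the EMA copy

ema.update_attr(model)                                     # sync non-tensor attributes before export
with torch.no_grad():
    preds = ema.ema(torch.randn(1, 10))                    # run inference with the smoothed weights
```

The EMA copy is never trained directly; it only tracks the training model, so saving `ema.ema.state_dict()` at the end usually gives the more stable weights for inference.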