Browse Source

modify loss of YOLOv1

yjh0410 2 years ago
parent
commit
2bd4e08c5b
1 changed files with 37 additions and 14 deletions
  1. 37 14
      utils/misc.py

+ 37 - 14
utils/misc.py

@@ -118,27 +118,50 @@ def is_parallel(model):
     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
 
 
+def de_parallel(model):
+    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    return model.module if is_parallel(model) else model
+
+
+def copy_attr(a, b, include=(), exclude=()):
+    # Copy attributes from b to a, options to only include [...] and to exclude [...]
+    for k, v in b.__dict__.items():
+        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+            continue
+        else:
+            setattr(a, k, v)
+
+
 # Model EMA
 class ModelEMA(object):
-    def __init__(self, model, decay=0.9999, updates=0):
-        # create EMA
-        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
-        self.updates = updates
-        self.decay = lambda x: decay * (1 - math.exp(-x / 2000.))
+    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
     def update(self, model):
         # Update EMA parameters
-        with torch.no_grad():
-            self.updates += 1
-            d = self.decay(self.updates)
-
-            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
-            for k, v in self.ema.state_dict().items():
-                if v.dtype.is_floating_point:
-                    v *= d
-                    v += (1. - d) * msd[k].detach()
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        copy_attr(self.ema, model, include, exclude)
 
 
 class CollateFunc(object):