
use MetricLogger for RTCTrainer

yjh0410 committed 1 year ago
commit ee91a4a3c0
2 changed files with 176 additions and 26 deletions:
  1. engine.py (+18 -25)
  2. utils/misc.py (+158 -1)

engine.py (+18 -25)

@@ -9,6 +9,7 @@ import random
 # ----------------- Extra Components -----------------
 from utils import distributed_utils
 from utils.misc import ModelEMA, CollateFunc, build_dataloader
+from utils.misc import MetricLogger, SmoothedValue
 from utils.vis_tools import vis_data
 
 # ----------------- Evaluator Components -----------------
@@ -926,14 +927,20 @@ class RTCTrainer(object):
             dist.barrier()
 
     def train_one_epoch(self, model):
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        header = 'Epoch: [{} / {}]'.format(self.epoch, self.args.max_epoch)
+        epoch_size = len(self.train_loader)
+        print_freq = 10
+
         # basic parameters
         epoch_size = len(self.train_loader)
         img_size = self.args.img_size
-        t0 = time.time()
         nw = epoch_size * self.args.wp_epoch
 
         # Train one epoch
-        for iter_i, (images, targets) in enumerate(self.train_loader):
+        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
             ni = iter_i + self.epoch * epoch_size
             # Warmup
             if ni <= nw:
@@ -990,29 +997,11 @@ class RTCTrainer(object):
                 if self.model_ema is not None:
                     self.model_ema.update(model)
 
-            # Logs
-            if distributed_utils.is_main_process() and iter_i % 10 == 0:
-                t1 = time.time()
-                cur_lr = [param_group['lr']  for param_group in self.optimizer.param_groups]
-                # basic infor
-                log =  '[Epoch: {}/{}]'.format(self.epoch, self.args.max_epoch)
-                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
-                log += '[lr: {:.6f}]'.format(cur_lr[2])
-                # loss infor
-                for k in loss_dict_reduced.keys():
-                    loss_val = loss_dict_reduced[k]
-                    if k == 'losses':
-                        loss_val *= self.grad_accumulate
-                    log += '[{}: {:.2f}]'.format(k, loss_val)
-                # other infor
-                log += '[grad_norm: {:.2f}]'.format(grad_norm)
-                log += '[time: {:.2f}]'.format(t1 - t0)
-                log += '[size: {}]'.format(img_size)
-
-                # print log infor
-                print(log, flush=True)
-                
-                t0 = time.time()
+            # Update log
+            metric_logger.update(**loss_dict_reduced)
+            metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
+            metric_logger.update(grad_norm=grad_norm)
+            metric_logger.update(size=img_size)
 
             if self.args.debug:
                 print("For debug mode, we only train 1 iteration")
@@ -1022,6 +1011,10 @@ class RTCTrainer(object):
         if not self.second_stage:
             self.lr_scheduler.step()
 
+        # Gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        print("Averaged stats:", metric_logger)
+
     def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
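
Taken together, the engine.py changes swap the hand-rolled print block for the MetricLogger pattern: meters are updated every iteration, log_every handles periodic printing, timing, and ETA, and the meters are averaged across processes at the end of the epoch. A minimal sketch of the new flow, assuming utils.misc is importable from the repo root and using hypothetical stand-ins (loader, optimizer, the loss dict) for the trainer's real objects:

    from utils.misc import MetricLogger, SmoothedValue

    def train_one_epoch_sketch(loader, optimizer, epoch, max_epoch, print_freq=10):
        metric_logger = MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{} / {}]'.format(epoch, max_epoch)
        # log_every wraps the loader: it yields batches unchanged and prints
        # meters, ETA, and per-iteration timing every print_freq iterations
        for iter_i, batch in enumerate(metric_logger.log_every(loader, print_freq, header)):
            loss_dict = {'losses': float(iter_i)}  # stand-in for the reduced loss dict
            metric_logger.update(**loss_dict)
            metric_logger.update(lr=optimizer.param_groups[0]['lr'])
        # average meter totals across DDP ranks before the epoch summary
        metric_logger.synchronize_between_processes()
        print('Averaged stats:', metric_logger)

The trainer itself reads param_groups[2] because its optimizer defines several parameter groups; the sketch assumes a single group.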

utils/misc.py (+158 -1)

@@ -1,14 +1,171 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.utils.data import DataLoader, DistributedSampler
-import torchvision
 
 import cv2
 import math
+import time
+import datetime
 import numpy as np
 from copy import deepcopy
 from thop import profile
+from collections import defaultdict, deque
+
+from .distributed_utils import is_dist_avail_and_initialized
+
+
+# ---------------------------- Train tools ----------------------------
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}',
+                'max mem: {memory:.0f}'
+            ])
+        else:
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}'
+            ])
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
 
 
 # ---------------------------- For Dataset ----------------------------
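
For reference, the SmoothedValue added above keeps a fixed-length deque for the windowed statistics (median, avg, max, value) while total and count accumulate over the whole run, so global_avg survives past the window. A quick illustration, assuming utils.misc is on the path (no distributed setup required):

    from utils.misc import SmoothedValue

    v = SmoothedValue(window_size=3, fmt='{median:.2f} ({global_avg:.2f})')
    for x in [1.0, 2.0, 3.0, 10.0]:
        v.update(x)
    print(v)               # 3.00 (4.00): median of the last 3 values, average of all 4
    print(v.max, v.value)  # 10.0 10.0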