use MetricLogger for RTCTrainer

yjh0410 1 year ago
parent
commit
4fc948a740
1 file changed with 36 additions and 4 deletions

+ 36 - 4
engine.py
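
Context for the patch below: the commit drops the manual t0 = time.time() bookkeeping and instead drives the training loop through metric_logger.log_every(...), which wraps the data loader, yields batches unchanged, and prints smoothed meters plus an ETA every print_freq iterations. MetricLogger and SmoothedValue here follow the convention popularized by DETR/torchvision; the sketch below is an assumption based on that convention, not the actual implementation imported by engine.py.

    import datetime
    import time
    from collections import defaultdict, deque


    class SmoothedValue:
        """Track a series of values; report window median, latest value, global average."""

        def __init__(self, window_size=20, fmt='{median:.4f} ({global_avg:.4f})'):
            self.deque = deque(maxlen=window_size)
            self.total = 0.0
            self.count = 0
            self.fmt = fmt

        def update(self, value, n=1):
            self.deque.append(value)
            self.count += n
            self.total += value * n

        @property
        def median(self):
            vals = sorted(self.deque)
            return vals[len(vals) // 2]

        @property
        def value(self):
            return self.deque[-1]

        @property
        def global_avg(self):
            return self.total / max(self.count, 1)

        def __str__(self):
            return self.fmt.format(median=self.median, value=self.value,
                                   global_avg=self.global_avg)


    class MetricLogger:
        """Collect named SmoothedValue meters and print them periodically."""

        def __init__(self, delimiter='\t'):
            self.meters = defaultdict(SmoothedValue)
            self.delimiter = delimiter

        def add_meter(self, name, meter):
            self.meters[name] = meter

        def update(self, **kwargs):
            for name, value in kwargs.items():
                if hasattr(value, 'item'):   # unwrap 0-dim tensors
                    value = value.item()
                self.meters[name].update(value)

        def __str__(self):
            return self.delimiter.join(
                '{}: {}'.format(name, meter)
                for name, meter in self.meters.items() if len(meter.deque) > 0)

        def log_every(self, iterable, print_freq, header=''):
            # Yield every item unchanged; print meters and an ETA periodically.
            total, start = len(iterable), time.time()
            for i, obj in enumerate(iterable):
                yield obj
                if i % print_freq == 0 or i == total - 1:
                    eta = (time.time() - start) / (i + 1) * (total - i - 1)
                    print('{} [{}/{}] eta: {} {}'.format(
                        header, i, total,
                        datetime.timedelta(seconds=int(eta)), self))

Usage mirroring the diff, with a dummy iterable standing in for self.train_loader:

    logger = MetricLogger(delimiter='  ')
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    for step, batch in enumerate(logger.log_every(range(100), 10, 'Epoch: [1 / 300]')):
        logger.update(lr=1e-4, loss=1.0 / (step + 1))
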

@@ -1679,14 +1679,20 @@ class RTCTrainerDS(object):
             dist.barrier()
 
     def train_one_epoch(self, model):
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        header = 'Epoch: [{} / {}]'.format(self.epoch, self.args.max_epoch)
+        epoch_size = len(self.train_loader)
+        print_freq = 10
+
         # basic parameters
         epoch_size = len(self.train_loader)
         img_size = self.args.img_size
-        t0 = time.time()
         nw = epoch_size * self.args.wp_epoch
 
         # Train one epoch
-        for iter_i, (images, targets) in enumerate(self.train_loader):
+        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
             ni = iter_i + self.epoch * epoch_size
             # Warmup
             if ni <= nw:
@@ -1722,6 +1728,12 @@ class RTCTrainerDS(object):
 
                 # TODO: finish the backward + optimize
 
+            # # Update log
+            # metric_logger.update(**loss_dict_reduced)
+            # metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
+            # metric_logger.update(grad_norm=grad_norm)
+            # metric_logger.update(size=img_size)
+
             if self.args.debug:
                 print("For debug mode, we only train 1 iteration")
                 break
@@ -1730,6 +1742,10 @@ class RTCTrainerDS(object):
         if not self.second_stage:
             self.lr_scheduler.step()
 
+        # Gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        print("Averaged stats:", metric_logger)
+
     def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
@@ -2006,14 +2022,20 @@ class RTCTrainerDSP(object):
             dist.barrier()
 
     def train_one_epoch(self, model):
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        header = 'Epoch: [{} / {}]'.format(self.epoch, self.args.max_epoch)
+        epoch_size = len(self.train_loader)
+        print_freq = 10
+
         # basic parameters
         epoch_size = len(self.train_loader)
         img_size = self.args.img_size
-        t0 = time.time()
         nw = epoch_size * self.args.wp_epoch
 
         # Train one epoch
-        for iter_i, (images, targets) in enumerate(self.train_loader):
+        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
             ni = iter_i + self.epoch * epoch_size
             # Warmup
             if ni <= nw:
@@ -2050,6 +2072,12 @@ class RTCTrainerDSP(object):
                 
                 # TODO: finish the backward + optimize
 
+            # # Update log
+            # metric_logger.update(**loss_dict_reduced)
+            # metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
+            # metric_logger.update(grad_norm=grad_norm)
+            # metric_logger.update(size=img_size)
+
             if self.args.debug:
                 print("For debug mode, we only train 1 iteration")
                 break
@@ -2058,6 +2086,10 @@ class RTCTrainerDSP(object):
         if not self.second_stage:
             self.lr_scheduler.step()
 
+        # Gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        print("Averaged stats:", metric_logger)
+
     def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
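
The "Gather the stats from all processes" block added at the end of both train_one_epoch methods relies on metric_logger.synchronize_between_processes(). In the DETR-style convention this all-reduces each meter's running count and total so every rank reports the same averaged stats. A minimal sketch of such a method for the MetricLogger sketched above, assuming torch.distributed with a CUDA/NCCL backend (the real body in this repository may differ):

    import torch
    import torch.distributed as dist

    def synchronize_between_processes(self):
        # No-op for single-process runs (e.g. the --debug path above).
        if not (dist.is_available() and dist.is_initialized()):
            return
        for meter in self.meters.values():
            # Sum each meter's running count/total across ranks so that
            # global_avg agrees everywhere; device='cuda' assumes NCCL.
            t = torch.tensor([meter.count, meter.total],
                             dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(t)
            meter.count = int(t[0].item())
            meter.total = t[1].item()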