
add grad norm

yjh0410 1 year ago
parent
commit
52589d7659
1 changed file with 5 additions and 2 deletions

yolo/engine.py (+5 −2)

@@ -165,6 +165,7 @@ class YoloTrainer(object):
         metric_logger = MetricLogger(delimiter="  ")
         metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
         metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        metric_logger.add_meter('gnorm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
         header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
         epoch_size = len(self.train_loader)
         print_freq = 10
@@ -208,17 +209,18 @@ class YoloTrainer(object):
                 # Compute loss
                 loss_dict = self.criterion(outputs=outputs, targets=targets)
                 losses = loss_dict['losses']
-                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
                 losses /= self.grad_accumulate
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
 
             # Backward
             self.scaler.scale(losses).backward()
+            grad_norm = None
 
             # Optimize
             if (iter_i + 1) % self.grad_accumulate == 0:
                 if self.cfg.clip_max_norm > 0:
                     self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
                 self.scaler.step(self.optimizer)
                 self.scaler.update()
                 self.optimizer.zero_grad()
@@ -231,6 +233,7 @@ class YoloTrainer(object):
             metric_logger.update(**loss_dict_reduced)
             metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
             metric_logger.update(size=img_size)
+            metric_logger.update(gnorm=grad_norm)
 
             if self.args.debug:
                 print("For debug mode, we only train 1 iteration")