yjh0410 committed 1 year ago
parent commit aa1a37d1fa
1 file changed, 11 insertions(+), 15 deletions(-)

yolo/engine.py (+11, -15)

@@ -62,8 +62,7 @@ class YoloTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        self.grad_accumulate = max(64 // args.batch_size, 1)
-        cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
+        cfg.base_lr = cfg.per_image_lr * args.batch_size
         cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
         self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
 
@@ -209,25 +208,22 @@ class YoloTrainer(object):
                 # Compute loss
                 loss_dict = self.criterion(outputs=outputs, targets=targets)
                 losses = loss_dict['losses']
-                losses /= self.grad_accumulate
                 loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
 
             # Backward
             self.scaler.scale(losses).backward()
-            gnorm = get_total_grad_norm(model.parameters())
 
             # Optimize
-            if (iter_i + 1) % self.grad_accumulate == 0:
-                if self.cfg.clip_max_norm > 0:
-                    self.scaler.unscale_(self.optimizer)
-                    gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-                self.optimizer.zero_grad()
-
-                # ModelEMA
-                if self.model_ema is not None:
-                    self.model_ema.update(model)
+            if self.cfg.clip_max_norm > 0:
+                self.scaler.unscale_(self.optimizer)
+                gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            self.optimizer.zero_grad()
+
+            # ModelEMA
+            if self.model_ema is not None:
+                self.model_ema.update(model)
 
             # Update log
             metric_logger.update(**loss_dict_reduced)
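
For context, the following is a minimal, self-contained sketch of the two training-loop patterns this diff switches between: gradient accumulation with a scaled learning rate (removed) versus a per-iteration optimizer step (kept). The linear model, random data, batch_size = 16, and max_norm = 10.0 are hypothetical stand-ins, not values taken from this repository.

# Sketch only: hypothetical model/data, not this repo's YoloTrainer.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Linear(8, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
data = [(torch.randn(2, 8, device=device), torch.randn(2, 4, device=device))
        for _ in range(8)]

# Before this commit: accumulate gradients so the effective batch is ~64 images
# (grad_accumulate = max(64 // batch_size, 1), here with batch_size = 16), and
# base_lr is scaled by per_image_lr * batch_size * grad_accumulate.
grad_accumulate = max(64 // 16, 1)
for iter_i, (x, y) in enumerate(data):
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y) / grad_accumulate
    scaler.scale(loss).backward()
    if (iter_i + 1) % grad_accumulate == 0:   # step only every N iterations
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

# After this commit: step on every iteration, so base_lr is simply
# per_image_lr * batch_size with no grad_accumulate factor.
for x, y in data:
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()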