yjh0410 1 year ago
parent commit 4c5b1a665e
2 changed files with 17 additions and 12 deletions
  1. +15 −11  yolo/engine.py
  2. +2 −1    yolo/train.sh

+ 15 - 11
yolo/engine.py

@@ -62,7 +62,8 @@ class YoloTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

         # ---------------------------- Build Optimizer ----------------------------
-        cfg.base_lr = cfg.per_image_lr * args.batch_size
+        self.grad_accumulate = max(64 // args.batch_size, 1)
+        cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
         cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
         self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)

@@ -168,6 +169,7 @@ class YoloTrainer(object):
         header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
         epoch_size = len(self.train_loader)
         print_freq = 10
+        gnorm = 0.0

         # basic parameters
         epoch_size = len(self.train_loader)
@@ -208,28 +210,30 @@ class YoloTrainer(object):
                 # Compute loss
                 loss_dict = self.criterion(outputs=outputs, targets=targets)
                 losses = loss_dict['losses']
+                losses /= self.grad_accumulate
                 loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

             # Backward
             self.scaler.scale(losses).backward()

             # Optimize
-            if self.cfg.clip_max_norm > 0:
-                self.scaler.unscale_(self.optimizer)
-                gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-            self.optimizer.zero_grad()
+            if (iter_i + 1) % self.grad_accumulate == 0:
+                if self.cfg.clip_max_norm > 0:
+                    self.scaler.unscale_(self.optimizer)
+                    gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()

-            # ModelEMA
-            if self.model_ema is not None:
-                self.model_ema.update(model)
+                # ModelEMA
+                if self.model_ema is not None:
+                    self.model_ema.update(model)

             # Update log
             metric_logger.update(**loss_dict_reduced)
             metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
-            metric_logger.update(gnorm=gnorm)
             metric_logger.update(size=img_size)
+            metric_logger.update(gnorm=gnorm)

             if self.args.debug:
                 print("For debug mode, we only train 1 iteration")

+ 2 - 1
yolo/train.sh

@@ -16,7 +16,8 @@ if [[ $WORLD_SIZE == 1 ]]; then
             --root ${DATASET_ROOT} \
             --model ${MODEL} \
             --batch_size ${BATCH_SIZE} \
-            --resume ${RESUME}
+            --resume ${RESUME} \
+            --fp16
 elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
     python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
             --cuda \