modify lr warmup with grad accumulate

yjh0410 1 year ago
parent commit fb0406e314
3 changed files with 8 additions and 7 deletions
  1. odlab/engine.py  +7 −5
  2. odlab/train.py  +0 −1
  3. odlab/utils/lr_scheduler.py  +1 −1

+ 7 - 5
odlab/engine.py

@@ -36,11 +36,13 @@ def train_one_epoch(cfg,
     for iter_i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
         ni = iter_i + epoch * epoch_size
         # WarmUp
-        if ni < cfg.warmup_iters:
-            warmup_lr_scheduler(ni, optimizer)
-        elif ni == cfg.warmup_iters:
-            print('Warmup stage is over.')
-            warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)
+        if ni % cfg.grad_accumulate == 0:
+            ni = ni // cfg.grad_accumulate
+            if ni < cfg.warmup_iters:
+                warmup_lr_scheduler(ni, optimizer)
+            elif ni == cfg.warmup_iters:
+                print('Warmup stage is over.')
+                warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)
 
         # To device
         images, masks = samples
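
With gradient accumulation, the optimizer only steps once every `cfg.grad_accumulate` micro-batches, so the warmup counter should advance at that same rate. A minimal sketch of the new behavior (the ramp formula and the `warmup_factor` value are assumptions for illustration, not taken from this commit):

```python
grad_accumulate = 4   # micro-batches per optimizer step (cfg.grad_accumulate)
warmup_iters = 3      # warmup length, counted in optimizer steps (cfg.warmup_iters)
base_lr = 0.01        # cfg.base_lr

def warmup_lr(step, warmup_iters, base_lr, warmup_factor=0.001):
    # Assumed linear ramp from base_lr * warmup_factor up to base_lr,
    # in the spirit of the repo's 'linear' warmup option.
    alpha = step / warmup_iters
    return base_lr * (warmup_factor * (1.0 - alpha) + alpha)

for ni in range(16):                  # raw micro-batch index
    if ni % grad_accumulate == 0:     # only on iterations that update the optimizer
        step = ni // grad_accumulate  # optimizer-step index
        if step < warmup_iters:
            print(f'micro-iter {ni:2d} -> step {step}: lr = {warmup_lr(step, warmup_iters, base_lr):.5f}')
        elif step == warmup_iters:
            print(f'micro-iter {ni:2d} -> step {step}: warmup over, lr = {base_lr}')
```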

+ 0 - 1
odlab/train.py

@@ -146,7 +146,6 @@ def main():
     optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, args.resume)
 
     # ---------------------------- Build LR Scheduler ----------------------------
-    cfg.warmup_iters = cfg.warmup_iters * cfg.grad_accumulate
     wp_lr_scheduler = build_wp_lr_scheduler(cfg)
     lr_scheduler    = build_lr_scheduler(cfg, optimizer, args.resume)
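
Removing the scaling here avoids counting the accumulation factor twice: `cfg.warmup_iters` is now interpreted directly as a number of optimizer steps, and the engine divides the raw iteration index instead. A small self-check of that equivalence (the values are illustrative):

```python
grad_accumulate = 4
warmup_iters = 500  # in optimizer steps

for ni in range(grad_accumulate * warmup_iters + grad_accumulate):
    # Old scheme: threshold pre-scaled in train.py, checked every micro-iteration.
    old_in_warmup = ni < warmup_iters * grad_accumulate
    # New scheme: counter divided in engine.py, checked only on update iterations.
    new_in_warmup = (ni % grad_accumulate == 0
                     and ni // grad_accumulate < warmup_iters)
    # Both leave warmup after the same number of optimizer updates, but the
    # new scheme only touches the LR on iterations where the optimizer steps.
    if new_in_warmup:
        assert old_in_warmup
```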
 

+ 1 - 1
odlab/utils/lr_scheduler.py

@@ -27,7 +27,7 @@ def build_wp_lr_scheduler(cfg):
     print('==============================')
     print('WarmUpScheduler: {}'.format(cfg.warmup))
     print('--base_lr: {}'.format(cfg.base_lr))
-    print('--warmup_iters: {}'.format(cfg.warmup_iters))
+    print('--warmup_iters: {} ({})'.format(cfg.warmup_iters, cfg.warmup_iters * cfg.grad_accumulate))
     print('--warmup_factor: {}'.format(cfg.warmup_factor))
 
     if cfg.warmup == 'linear':
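
The log line now reports the warmup length both in optimizer steps and, in parentheses, in raw micro-iterations (`warmup_iters * grad_accumulate`). For context, here is a minimal sketch of a linear warmup scheduler matching how the engine calls it (`warmup_lr_scheduler(ni, optimizer)` and `.set_lr(optimizer, lr)`); the class name and internals are assumptions, not the repo's actual implementation:

```python
class LinearWarmUpScheduler:
    def __init__(self, base_lr, warmup_iters, warmup_factor):
        self.base_lr = base_lr
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr):
        # Write the given LR into every param group.
        for group in optimizer.param_groups:
            group['lr'] = lr

    def __call__(self, step, optimizer):
        # Linearly ramp the LR from base_lr * warmup_factor to base_lr
        # over warmup_iters optimizer steps.
        alpha = step / self.warmup_iters
        factor = self.warmup_factor * (1.0 - alpha) + alpha
        self.set_lr(optimizer, self.base_lr * factor)
```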