|
|
@@ -36,11 +36,13 @@ def train_one_epoch(cfg,
|
|
|
for iter_i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
|
|
ni = iter_i + epoch * epoch_size
|
|
|
# WarmUp
|
|
|
- if ni < cfg.warmup_iters:
|
|
|
- warmup_lr_scheduler(ni, optimizer)
|
|
|
- elif ni == cfg.warmup_iters:
|
|
|
- print('Warmup stage is over.')
|
|
|
- warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)
|
|
|
+ if ni % cfg.grad_accumulate == 0:
|
|
|
+ ni = ni // cfg.grad_accumulate
|
|
|
+ if ni < cfg.warmup_iters:
|
|
|
+ warmup_lr_scheduler(ni, optimizer)
|
|
|
+ elif ni == cfg.warmup_iters:
|
|
|
+ print('Warmup stage is over.')
|
|
|
+ warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)
|
|
|
|
|
|
# To device
|
|
|
images, masks = samples
|