|
|
@@ -71,11 +71,11 @@ def train_one_epoch(cfg,
|
|
|
|
|
|
# Backward
|
|
|
losses.backward()
|
|
|
- if cfg.clip_max_norm > 0:
|
|
|
- torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)
|
|
|
|
|
|
# Optimize
|
|
|
if (iter_i + 1) % cfg.grad_accumulate == 0:
|
|
|
+ if cfg.clip_max_norm > 0:
|
|
|
+ torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)
|
|
|
optimizer.step()
|
|
|
optimizer.zero_grad()
|
|
|
|