- import sys
- import math
- from utils.misc import MetricLogger, SmoothedValue
def train_one_epoch(args,
                    device,
                    model,
                    data_loader,
                    optimizer,
                    epoch,
                    lr_scheduler_warmup,
                    ):
    """Train the model for one epoch and return the averaged metrics.

    Args:
        args: namespace providing at least ``wp_epoch`` (number of warmup
            epochs) and ``base_lr`` (learning rate applied once warmup ends).
        device: device the image batches are moved to before the forward pass.
        model: model whose forward pass returns a dict containing key "loss".
        data_loader: iterable of ``(images, targets)`` batches; targets are
            ignored here.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based index of the current epoch (used to compute the
            global iteration counter for warmup).
        lr_scheduler_warmup: callable ``(iteration, optimizer)`` that sets the
            warmup learning rate; must also expose ``set_lr(optimizer, lr)``.

    Returns:
        dict mapping each logged metric name to its global average.
    """
    model.train(True)
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    epoch_size = len(data_loader)
    # Total number of warmup iterations — loop-invariant, so computed once
    # here instead of on every batch.
    nw = args.wp_epoch * epoch_size

    # Train one epoch
    for iter_i, (images, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # Global iteration counter across epochs (drives warmup scheduling).
        ni = iter_i + epoch * epoch_size

        # Warmup: ramp the lr until iteration nw, then pin it to base_lr
        # exactly once when the warmup boundary is crossed.
        if nw > 0 and ni < nw:
            lr_scheduler_warmup(ni, optimizer)
        elif ni == nw:
            print("Warmup stage is over.")
            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)

        # To device
        images = images.to(device, non_blocking=True)

        # Inference — the model computes its own loss internally.
        output = model(images)

        # Compute loss
        loss = output["loss"]

        # Check loss: abort immediately on NaN/inf so a diverged run
        # fails loudly instead of training on garbage gradients.
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Backward
        loss.backward()

        # Optimize
        optimizer.step()
        optimizer.zero_grad()

        # Logs
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=lr)

    # NOTE(review): the original comment said "gather the stats from all
    # processes", but no cross-process synchronization call is visible here —
    # confirm whether metric_logger.synchronize_between_processes() (or
    # similar) should be invoked before reading the averages.
    print("Averaged stats: {}".format(metric_logger))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}