engine_pretrain.py

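"""Training engine: runs a single pretraining epoch with learning-rate warmup."""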
import sys
import math

from utils.misc import MetricLogger, SmoothedValue


def train_one_epoch(args,
                    device,
                    model,
                    data_loader,
                    optimizer,
                    epoch,
                    lr_scheduler_warmup,
                    ):
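    """Train the model for one pretraining epoch.

    Applies the warmup schedule for the first ``args.wp_epoch`` epochs,
    then trains at the base learning rate. Returns a dict mapping each
    logged metric (e.g. ``loss``, ``lr``) to its epoch-wide average.
    """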
    model.train(True)
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    epoch_size = len(data_loader)

    # Train one epoch
    for iter_i, (images, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        ni = iter_i + epoch * epoch_size  # global iteration index
        nw = args.wp_epoch * epoch_size   # total number of warmup iterations
        # Warmup: ramp the learning rate until nw iterations have elapsed,
        # then switch over to the base learning rate.
        if nw > 0 and ni < nw:
            lr_scheduler_warmup(ni, optimizer)
        elif ni == nw:
            print("Warmup stage is over.")
            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)

        # Move the batch to the target device
        images = images.to(device, non_blocking=True)

        # Forward pass: the model is expected to return a dict with a "loss" entry
        output = model(images)

        # Compute loss
        loss = output["loss"]

        # Check loss: abort on NaN/inf rather than corrupting the weights
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Backward
        loss.backward()

        # Optimize
        optimizer.step()
        optimizer.zero_grad()

        # Logs
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=lr)

    # Report the epoch-averaged stats. Note: as written, the averages are
    # per-process; a distributed run would need to synchronize the meters
    # across processes before reading them.
    print("Averaged stats: {}".format(metric_logger))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
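
# Example call site (hypothetical surrounding setup, for illustration only):
#   stats = train_one_epoch(args, device, model, data_loader,
#                           optimizer, epoch, lr_scheduler_warmup)
#   print("epoch avg loss:", stats["loss"])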