engine.py

import sys
import math
import numpy as np
import torch

from utils.misc import MetricLogger, SmoothedValue
from utils.misc import print_rank_0, all_reduce_mean, accuracy


def train_one_epoch(args,
                    device,
                    model,
                    model_ema,
                    data_loader,
                    optimizer,
                    epoch,
                    lr_scheduler_warmup,
                    loss_scaler,
                    criterion,
                    local_rank=0,
                    tblogger=None,
                    mixup_fn=None):
    model.train(True)
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, args.max_epoch)
    print_freq = 20
    epoch_size = len(data_loader)

    optimizer.zero_grad()
    # train one epoch
    for iter_i, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        ni = iter_i + epoch * epoch_size
        nw = args.wp_epoch * epoch_size
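        # 'ni' is the global iteration index across epochs; 'nw' is the total
        # number of warmup iterations (args.wp_epoch epochs worth of batches).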
        # Warmup
        if nw > 0 and ni < nw:
            lr_scheduler_warmup(ni, optimizer)
        elif ni == nw:
            print("Warmup stage is over.")
            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)

        # To device
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Mixup
        if mixup_fn is not None:
            images, targets = mixup_fn(images, targets)
        # Inference
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, targets)

        # Check loss
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
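        # Gradient accumulation: each mini-batch contributes loss / grad_accumulate
        # to the accumulated gradients, and the optimizer only steps on every
        # args.grad_accumulate-th iteration (emulating a larger effective batch).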
        # Backward & Optimize
        loss /= args.grad_accumulate
        loss_scaler(loss, optimizer, clip_grad=args.max_grad_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(iter_i + 1) % args.grad_accumulate == 0)
        if (iter_i + 1) % args.grad_accumulate == 0:
            optimizer.zero_grad()
            if model_ema is not None:
                model_ema.update(model)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # Logs
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=lr)

        loss_value_reduce = all_reduce_mean(loss_value)
        if tblogger is not None and (iter_i + 1) % args.grad_accumulate == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((iter_i / len(data_loader) + epoch) * 1000)
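            # e.g. with len(data_loader) == 500, iter_i == 250 and epoch == 3:
            # epoch_1000x = int((250 / 500 + 3) * 1000) = 3500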
            tblogger.add_scalar('loss', loss_value_reduce, epoch_1000x)
            tblogger.add_scalar('lr', lr, epoch_1000x)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print_rank_0("Averaged stats: {}".format(metric_logger), local_rank)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, local_rank):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print_rank_0('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
                 .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss),
                 local_rank)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
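

if __name__ == "__main__":
    # Minimal smoke-test sketch for evaluate() -- illustrative only, not part of
    # the original training pipeline. The tiny model and random dataset below are
    # placeholders, and it assumes utils.misc provides the helpers imported above.
    import torch.utils.data as data

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(3 * 32 * 32, 10),
    ).to(device)

    # Random 32x32 RGB images with 10 fake classes.
    images = torch.randn(64, 3, 32, 32)
    labels = torch.randint(0, 10, (64,))
    loader = data.DataLoader(data.TensorDataset(images, labels), batch_size=16)

    stats = evaluate(loader, model, device, local_rank=0)
    print("top-1: {:.3f}, top-5: {:.3f}".format(stats['acc1'], stats['acc5']))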