# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable

import torch

from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data


def train_one_epoch(cfg,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    vis_target: bool,
                    warmup_lr_scheduler,
                    class_labels=None,
                    model_ema=None,
                    debug: bool = False):
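    """
    Train the model for one epoch.

    cfg is expected to provide 'max_epoch', 'warmup_iters', 'base_lr' and
    'clip_max_norm'; 'normalize_coords' and 'box_format' are only read when
    vis_target is True. data_loader yields ((images, masks), targets) batches,
    which are moved to `device` before the forward pass. warmup_lr_scheduler
    drives the learning-rate warmup, model_ema (if given) is updated after every
    optimizer step, and debug stops training after a single iteration.
    """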
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, cfg['max_epoch'])
    epoch_size = len(data_loader)
    print_freq = 10
    iteration = 0

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
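        # ni is the global iteration index across all epochs; it drives the warmup schedule below.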
        ni = iteration + epoch * epoch_size
        # WarmUp
        if ni < cfg['warmup_iters']:
            warmup_lr_scheduler(ni, optimizer)
        elif ni == cfg['warmup_iters']:
            print('Warmup stage is over.')
            warmup_lr_scheduler.set_lr(optimizer, cfg['base_lr'])

        # To device
        images, masks = samples
        images = images.to(device)
        masks = masks.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Visualize train targets
        if vis_target:
            vis_data(images, targets, masks, class_labels, cfg['normalize_coords'], cfg['box_format'])

        # Inference
        outputs = model(images, masks, targets)
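
        # The total loss is the weighted sum of the terms returned by the criterion;
        # loss terms missing from criterion.weight_dict are excluded from it.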
        # Compute loss
        loss_dict = criterion(outputs, targets)
        loss_weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * loss_weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in loss_weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        # Check loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # Backward
        optimizer.zero_grad()
        losses.backward()
        if cfg['clip_max_norm'] > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg['clip_max_norm'])
        optimizer.step()
        iteration += 1

        # EMA: update the exponential moving average of the model weights
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if debug:
            print("In debug mode, we only train the model for 1 iteration.")
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
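

# Illustrative usage (a minimal sketch, not part of the original file). The builder names
# below (build_model, build_criterion, build_dataloader, build_warmup_scheduler) are
# hypothetical stand-ins for whatever main.py actually uses; only train_one_epoch and its
# argument order come from this module.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model, criterion = build_model(cfg), build_criterion(cfg)
#   model.to(device)
#   data_loader = build_dataloader(cfg)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['base_lr'])
#   warmup_lr_scheduler = build_warmup_scheduler(cfg)
#   for epoch in range(cfg['max_epoch']):
#       stats = train_one_epoch(cfg, model, criterion, data_loader, optimizer, device,
#                               epoch, vis_target=False,
#                               warmup_lr_scheduler=warmup_lr_scheduler)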