engine.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable

import torch

from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data


def train_one_epoch(cfg,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    vis_target: bool,
                    warmup_lr_scheduler,
                    debug: bool = False):
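    """
    Run one training epoch with gradient accumulation and learning-rate warmup.
    Reads these cfg fields: max_epoch, grad_accumulate, warmup_iters, base_lr,
    class_labels, normalize_coords, box_format, and clip_max_norm.
    Returns a dict mapping each logged meter name to its global average.
    """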
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, cfg.max_epoch)
    epoch_size = len(data_loader)
    print_freq = 10

    optimizer.zero_grad()
    for iter_i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        ni = iter_i + epoch * epoch_size
        # Warmup: progress is counted in optimizer steps (ni // grad_accumulate),
        # so the warmup length does not depend on the accumulation factor.
        if ni % cfg.grad_accumulate == 0:
            ni = ni // cfg.grad_accumulate
            if ni < cfg.warmup_iters:
                warmup_lr_scheduler(ni, optimizer)
            elif ni == cfg.warmup_iters:
                print('Warmup stage is over.')
                warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)
        # To device
        images, masks = samples
        images = images.to(device)
        masks = masks.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Visualize train targets
        if vis_target:
            vis_data(images, targets, masks, cfg.class_labels, cfg.normalize_coords, cfg.box_format)

        # Inference
        outputs = model(images, masks)

        # Compute loss
        loss_dict = criterion(outputs, targets)
        loss_weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
        loss_value = losses.item()
        # Scale the loss so that accumulated gradients match a single full-batch step
        losses /= cfg.grad_accumulate

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        # Check loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # Backward
        losses.backward()

        # Optimize: step and reset gradients only every grad_accumulate iterations
        if (iter_i + 1) % cfg.grad_accumulate == 0:
            if cfg.clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)
            optimizer.step()
            optimizer.zero_grad()
        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if debug:
            print("Debug mode: stop after one training iteration.")
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
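

# Usage sketch (illustrative only): how main.py might drive train_one_epoch.
# The builder helpers below (build_config, build_model, build_criterion,
# build_dataloader, build_warmup_scheduler) are hypothetical names, not part of
# this repository, so the example is left commented out.
#
# if __name__ == '__main__':
#     cfg = build_config()
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = build_model(cfg).to(device)
#     criterion = build_criterion(cfg).to(device)
#     data_loader = build_dataloader(cfg)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.base_lr)
#     warmup_lr_scheduler = build_warmup_scheduler(cfg)
#     for epoch in range(cfg.max_epoch):
#         train_stats = train_one_epoch(cfg, model, criterion, data_loader, optimizer,
#                                       device, epoch, vis_target=False,
#                                       warmup_lr_scheduler=warmup_lr_scheduler)
#         print(train_stats)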