engine.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable

import torch

from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data


def train_one_epoch(cfg,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    vis_target: bool,
                    warmup_lr_scheduler,
                    debug: bool = False,
                    ):
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, cfg.max_epoch)
    epoch_size = len(data_loader)
    print_freq = 10

    for iter_i, (samples, targets) in metric_logger.log_every(data_loader, print_freq, header):
        ni = iter_i + epoch * epoch_size
        # Warmup: ramp the learning rate for the first cfg.warmup_iters iterations,
        # then hand control back to the base learning rate.
        if ni < cfg.warmup_iters:
            warmup_lr_scheduler(ni, optimizer)
        elif ni == cfg.warmup_iters:
            print('Warmup stage is over.')
            warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)

        # To device
        images, masks = samples
        images = images.to(device)
        masks = masks.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Visualize train targets
        if vis_target:
            vis_data(images, targets, masks, cfg.class_labels, cfg.normalize_coords, cfg.box_format)

        # Inference
        outputs = model(images, masks, targets)

        # Compute loss
        loss_dict = criterion(outputs, targets)
        loss_weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
        loss_value = losses.item()
        # Scale the loss so gradients accumulated over several iterations average out
        losses /= cfg.grad_accumulate

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        # Check loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # Backward
        losses.backward()
        if cfg.clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)

        # Optimize: step the optimizer only once every cfg.grad_accumulate iterations
        if (iter_i + 1) % cfg.grad_accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if debug:
            print("In debug mode, we only train the model for 1 iteration.")
            break

    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
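

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original engine.py): one possible
# warmup scheduler that satisfies the interface train_one_epoch expects --
# callable as scheduler(ni, optimizer) while ni < cfg.warmup_iters, and
# exposing set_lr(optimizer, lr) for the hand-off to the base learning rate.
# The linear ramp and the default warmup_factor below are assumptions made
# for illustration only, not this repo's actual scheduler.
# ---------------------------------------------------------------------------
class LinearWarmupLRScheduler:
    def __init__(self, base_lr: float, warmup_iters: int, warmup_factor: float = 0.00066667):
        self.base_lr = base_lr
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer: torch.optim.Optimizer, lr: float):
        # Overwrite the learning rate of every parameter group.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def __call__(self, ni: int, optimizer: torch.optim.Optimizer):
        # Linearly ramp the learning rate from base_lr * warmup_factor up to base_lr.
        alpha = ni / self.warmup_iters
        factor = self.warmup_factor * (1.0 - alpha) + alpha
        self.set_lr(optimizer, self.base_lr * factor)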