# engine.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable

import torch

from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data


def train_one_epoch(cfg,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    vis_target: bool,
                    warmup_lr_scheduler,
                    debug: bool = False,
                    ):
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, cfg.max_epoch)
    epoch_size = len(data_loader)
    print_freq = 10

    optimizer.zero_grad()
    for iter_i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        ni = iter_i + epoch * epoch_size
        # WarmUp
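        # 'ni' is the global iteration index across epochs: the warmup scheduler
        # ramps up the learning rate for the first cfg.warmup_iters iterations,
        # after which the learning rate is reset to cfg.base_lr.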
        if ni < cfg.warmup_iters:
            warmup_lr_scheduler(ni, optimizer)
        elif ni == cfg.warmup_iters:
            print('Warmup stage is over.')
            warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)

        # To device
        images, masks = samples
        images = images.to(device)
        masks = masks.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Visualize train targets
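        # (vis_data renders the current batch together with its ground-truth
        # targets, which is a quick sanity check of the data pipeline.)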
        if vis_target:
            vis_data(images, targets, masks, cfg.class_labels, cfg.normalize_coords, cfg.box_format)

        # Inference
        outputs = model(images, masks)

        # Compute loss
        loss_dict = criterion(outputs, targets)
        loss_weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
        loss_value = losses.item()
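        # Scale the loss by the number of accumulation steps so that the gradients
        # summed over cfg.grad_accumulate backward passes match those of a single
        # pass over the equivalent larger batch (loss_value keeps the unscaled
        # value for logging).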
        losses /= cfg.grad_accumulate

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        # Check loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # Backward
        losses.backward()

        # Optimize
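        # Parameters are only updated every cfg.grad_accumulate iterations; the
        # accumulated gradients are clipped first when cfg.clip_max_norm > 0.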
        if (iter_i + 1) % cfg.grad_accumulate == 0:
            if cfg.clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)
            optimizer.step()
            optimizer.zero_grad()

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if debug:
            print("Debug mode: stopping after a single training iteration.")
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
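

# A rough sketch of how this function is typically driven from main.py; the cfg
# fields, the argparse flags, and the epoch-level lr_scheduler below are assumptions
# for illustration rather than definitions from this file:
#
#     for epoch in range(start_epoch, cfg.max_epoch):
#         train_stats = train_one_epoch(cfg, model, criterion, train_loader,
#                                       optimizer, device, epoch,
#                                       vis_target=args.vis_targets,
#                                       warmup_lr_scheduler=warmup_scheduler,
#                                       debug=args.debug)
#         lr_scheduler.step()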