engine.py

import math
import sys
from typing import Iterable

import torch

from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data


def train_one_epoch(cfg,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    vis_target: bool,
                    warmup_lr_scheduler,
                    debug: bool = False,
                    ):
    """Train the model for one epoch with LR warmup and optional gradient
    accumulation, and return the epoch-averaged training metrics."""
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, cfg.max_epoch)
    epoch_size = len(data_loader)
    print_freq = 10

    optimizer.zero_grad()
    for iter_i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        ni = iter_i + epoch * epoch_size
        # Warmup: ramp up the learning rate for the first cfg.warmup_iters optimizer steps
        # (ni counts optimizer steps, i.e. iterations divided by the gradient-accumulation factor)
        if ni % cfg.grad_accumulate == 0:
            ni = ni // cfg.grad_accumulate
            if ni < cfg.warmup_iters:
                warmup_lr_scheduler(ni, optimizer)
            elif ni == cfg.warmup_iters:
                print('Warmup stage is over.')
                warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)

        # To device
        images, masks = samples
        images = images.to(device)
        masks = masks.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Visualize train targets
        if vis_target:
            vis_data(images, targets, masks, cfg.class_labels, cfg.normalize_coords, cfg.box_format)

        # Inference
        outputs = model(images, masks)

        # Compute loss
        loss_dict = criterion(outputs, targets)
        losses = loss_dict["losses"]  # sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
        loss_value = losses.item()
        # Scale the loss for gradient accumulation
        losses /= cfg.grad_accumulate

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        # Check loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # Backward
        losses.backward()

        # Optimize: update the weights once every cfg.grad_accumulate iterations
        if (iter_i + 1) % cfg.grad_accumulate == 0:
            if cfg.clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)
            optimizer.step()
            optimizer.zero_grad()

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if debug:
            print("In debug mode, we only train the model for 1 iteration.")
            break

    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
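

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): a minimal, single-process example of how
# train_one_epoch can be driven. ToyModel, ToyCriterion, ToyWarmupScheduler
# and every cfg field below are hypothetical stand-ins inferred from the
# attribute accesses in the function body above; the real project builds its
# model, criterion, config and warmup scheduler elsewhere. The sketch also
# assumes the repo's MetricLogger / reduce_dict helpers fall back to no-ops
# when torch.distributed is not initialized.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyModel(torch.nn.Module):
        """Tiny stand-in model that accepts (images, masks) like the real one."""
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Conv2d(3, 4, kernel_size=1)

        def forward(self, images, masks):
            return self.proj(images)

    class ToyCriterion(torch.nn.Module):
        """Returns a loss dict with the 'losses' key expected by train_one_epoch."""
        def forward(self, outputs, targets):
            return {"losses": (outputs ** 2).mean()}

    class ToyWarmupScheduler:
        """Callable warmup scheduler with the set_lr() method used above."""
        def __call__(self, ni, optimizer):
            for group in optimizer.param_groups:
                group["lr"] = 1e-4 * (ni + 1) / 10

        def set_lr(self, optimizer, lr):
            for group in optimizer.param_groups:
                group["lr"] = lr

    cfg = SimpleNamespace(
        max_epoch=1, grad_accumulate=1, warmup_iters=10, base_lr=1e-4,
        clip_max_norm=0.1, class_labels=None, normalize_coords=False,
        box_format="xyxy",
    )
    device = torch.device("cpu")
    model = ToyModel().to(device)
    criterion = ToyCriterion()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.base_lr)

    # A single fake batch: (images, masks) plus one target dict per image.
    images = torch.randn(2, 3, 32, 32)
    masks = torch.zeros(2, 32, 32, dtype=torch.bool)
    targets = [{"labels": torch.zeros(1, dtype=torch.int64),
                "boxes": torch.rand(1, 4)} for _ in range(2)]
    data_loader = [((images, masks), targets)]

    stats = train_one_epoch(cfg, model, criterion, data_loader, optimizer,
                            device, epoch=0, vis_target=False,
                            warmup_lr_scheduler=ToyWarmupScheduler(),
                            debug=True)
    print("Sketch run finished, averaged stats:", stats)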