distributed_utils.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import subprocess

import torch
import torch.distributed as dist


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
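
# Example (illustrative sketch, not part of the original module): a training
# loop typically reduces a per-iteration loss dict before logging, so every
# rank reports the same averaged numbers. The keys and values below are
# hypothetical; with the NCCL backend the tensors must live on the GPU.
#
#   loss_dict = {'loss_ce': torch.tensor(0.7, device='cuda'),
#                'loss_bbox': torch.tensor(0.3, device='cuda')}
#   loss_dict_reduced = reduce_dict(loss_dict, average=True)
#   if is_main_process():
#       print({k: v.item() for k, v in loss_dict_reduced.items()})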


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
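
# Example (illustrative): once setup_for_distributed(is_master) has patched
# print, ordinary prints are silenced on non-master ranks, while the extra
# keyword popped above still lets any rank emit a message.
#
#   print('only the master rank sees this')
#   print('every rank sees this', force=True)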


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
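
# Example (illustrative sketch of a caller, not part of the original module):
# a training script would expose a dist_url argument and hand the argparse
# namespace to init_distributed_mode. Names other than the functions defined
# in this file (e.g. model_state) are hypothetical.
#
#   import argparse
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--dist_url', default='env://')
#   args = parser.parse_args()
#
#   init_distributed_mode(args)   # sets args.rank / args.gpu / args.distributed
#   print(get_sha())              # printed by the master rank only
#   ...
#   save_on_master(model_state, 'checkpoint.pth')  # checkpoints written from rank 0 only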