distributed_utils.py

# from github: https://github.com/ruinmessi/ASFF/blob/master/utils/distributed_util.py
import os
import pickle
import subprocess

import torch
import torch.distributed as dist


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a ByteTensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


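# Usage sketch (illustrative addition, not part of the original file): gather
# arbitrary picklable per-rank results, e.g. lists of predictions collected
# during distributed evaluation. The merging logic below is a hypothetical example.
def _example_gather_eval_results(local_results):
    """local_results is any picklable object, e.g. a list of per-image predictions."""
    gathered = all_gather(local_results)  # one entry per rank, identical on every rank
    if is_main_process():
        # e.g. flatten the per-rank lists into a single list on the main process
        return [item for per_rank in gathered for item in per_rank]
    return None

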
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


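# Usage sketch (illustrative addition, not part of the original file): average a
# dict of scalar loss tensors across ranks before logging, so every process sees
# the same numbers. The loss names below are hypothetical.
def _example_log_reduced_losses(loss_dict):
    """loss_dict maps names to 0-dim CUDA tensors, e.g. {'loss_cls': ..., 'loss_box': ...}."""
    reduced = reduce_dict(loss_dict, average=True)  # same keys, values averaged over ranks
    if is_main_process():
        for name, value in reduced.items():
            print(f"{name}: {value.item():.4f}")
    return reduced

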
def get_sha():
    """Return the current git commit hash, working-tree status and branch."""
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


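# Usage sketch (illustrative addition, not part of the original file): once
# setup_for_distributed(is_master) has been called, plain print() is silenced on
# non-master ranks, while force=True still prints from every rank.
def _example_rank_aware_logging():
    print("visible on the master process only")
    print(f"hello from rank {get_rank()}", force=True)  # printed by every rank

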
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


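# Usage sketch (illustrative addition, not part of the original file): write a
# checkpoint from the main process only. The checkpoint contents and path are
# hypothetical.
def _example_save_checkpoint(model, optimizer, epoch, path="checkpoint.pth"):
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    save_on_master(checkpoint, path)  # forwards to torch.save on rank 0, no-op elsewhere

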
def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables (torchrun or SLURM)
    and silence printing on non-master ranks."""
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


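# Usage sketch (illustrative addition, not part of the original file): a minimal
# entry point showing the argparse fields init_distributed_mode reads and fills in.
# rank/gpu/world_size/distributed come from the environment when launched with
# torchrun (which sets RANK, WORLD_SIZE and LOCAL_RANK); dist_url and a fallback
# world_size must be supplied by the caller.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="distributed init sketch")
    parser.add_argument("--dist_url", default="env://",
                        help="url used to set up distributed training")
    parser.add_argument("--world_size", default=1, type=int,
                        help="number of distributed processes (fallback for the SLURM branch)")
    args = parser.parse_args()
    init_distributed_mode(args)
    print(f"rank {get_rank()} of {get_world_size()} initialized")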