# from github: https://github.com/ruinmessi/ASFF/blob/master/utils/distributed_util.py
import os
import pickle
import subprocess

import torch
import torch.distributed as dist


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object to a byte tensor on the GPU
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receive tensors from all ranks;
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # strip the padding and unpickle each rank's payload
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
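

# A minimal usage sketch (not part of the original file): inside a distributed
# evaluation loop, all_gather can collect per-rank Python objects such as lists
# of detection results. The variable names below are illustrative only.
#
#     local_results = [{"image_id": 3, "score": 0.9}]   # results on this rank
#     merged = all_gather(local_results)                # one entry per rank
#     if is_main_process():
#         all_results = [r for per_rank in merged for r in per_rank]

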
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
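

# A minimal usage sketch (not part of the original file): logging a loss dict
# that is consistent across ranks. loss_dict is an illustrative name for a
# dict of scalar CUDA tensors produced by the model on each process.
#
#     loss_dict = {"loss_cls": loss_cls, "loss_reg": loss_reg}
#     loss_dict_reduced = reduce_dict(loss_dict)   # averaged over ranks
#     if is_main_process():
#         print({k: v.item() for k, v in loss_dict_reduced.items()})

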
def get_sha():
    """Return the current git commit hash, working-tree status, and branch as a string."""
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
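

# A minimal usage sketch (not part of the original file): log the repository
# state once at the start of training so runs can be traced back to a commit.
#
#     print("git: {}".format(get_sha()))

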
def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    Passing force=True to print() overrides this on non-master ranks.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)
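

# A minimal usage sketch (not part of the original file): checkpointing only on
# rank 0 so multiple processes do not write the same file concurrently.
# model, optimizer, and the output path are illustrative names.
#
#     save_on_master(
#         {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
#         "checkpoint.pth",
#     )

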
def init_distributed_mode(args):
    # torch.distributed.launch / torchrun set RANK, WORLD_SIZE and LOCAL_RANK;
    # SLURM only provides SLURM_PROCID, so the local GPU index is derived from the rank
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
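

# A minimal usage sketch (not part of the original file), assuming an argparse
# namespace that carries dist_url and world_size defaults and a script launched
# with something like `python -m torch.distributed.launch --nproc_per_node=8 train.py`:
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dist_url', default='env://')
#     parser.add_argument('--world_size', default=1, type=int)
#     args = parser.parse_args()
#     init_distributed_mode(args)   # sets args.rank / args.gpu / args.distributed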