| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import os
- import subprocess
- import torch
- import torch.distributed as dist
def reduce_dict(input_dict, average=True):
    """Reduce the values of a dictionary across all processes.

    Every process ends up with the same reduced results.

    Args:
        input_dict (dict): all the values will be reduced. The values are
            stacked with ``torch.stack``, so they must be tensors of
            identical shape.
        average (bool): whether to do average or sum

    Returns:
        dict: a dict with the same fields as input_dict, after reduction.
        The input object itself is returned when it is empty or when fewer
        than two processes take part.
    """
    # Nothing to reduce; torch.stack would raise on an empty list.
    if not input_dict:
        return input_dict
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        # Iterating the stacked tensor yields one row per key, in key order.
        reduced_dict = dict(zip(names, values))
    return reduced_dict
def get_sha():
    """Describe the git state of the directory that contains this file.

    Returns:
        str: ``"sha: <sha>, status: <status>, branch: <branch>"``. Fields
        that could not be determined keep their defaults (``N/A`` for sha
        and branch, ``clean`` for status); any failure while querying git is
        swallowed on purpose, so this never raises because of git.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git(*git_args):
        # Run one git sub-command in repo_dir and return its trimmed output.
        raw = subprocess.check_output(['git', *git_args], cwd=repo_dir)
        return raw.decode('ascii').strip()

    commit, status, branch_name = 'N/A', "clean", 'N/A'
    try:
        commit = _git('rev-parse', 'HEAD')
        # Output is discarded; presumably run so the index is refreshed
        # before diff-index is consulted -- TODO confirm.
        subprocess.check_output(['git', 'diff'], cwd=repo_dir)
        pending = _git('diff-index', 'HEAD')
        status = "has uncommited changes" if pending else "clean"
        branch_name = _git('rev-parse', '--abbrev-ref', 'HEAD')
    except Exception:
        # Best effort: keep whatever was gathered before the failure.
        pass
    return f"sha: {commit}, status: {status}, branch: {branch_name}"
def setup_for_distributed(is_master):
    """Replace ``builtins.print`` with a version gated on ``is_master``.

    After this call, ``print`` produces output only when ``is_master`` is
    true or the caller passes ``force=True``. The ``force`` keyword is
    consumed by the wrapper and is not forwarded to the wrapped print.
    """
    import builtins

    # Whatever print is installed right now becomes the wrapped target.
    wrapped_print = builtins.print

    def print(*args, force=False, **kwargs):
        if not (is_master or force):
            return
        wrapped_print(*args, **kwargs)

    builtins.print = print
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    # Short-circuit: is_initialized() is consulted only when the package
    # reports itself as available.
    return bool(dist.is_available() and dist.is_initialized())
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when the current process has rank 0."""
    current_rank = get_rank()
    return current_rank == 0
def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only.

    On every other process this is a no-op.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Initialise torch.distributed from environment variables.

    Two launch styles are recognised:

    * ``RANK`` and ``WORLD_SIZE`` both set: rank, world size and GPU index
      are read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    * ``SLURM_PROCID`` set: the rank is read from it and the GPU index is
      the rank modulo the number of visible CUDA devices. ``args.world_size``
      is not assigned on this path and must already be present.

    Otherwise ``args.distributed`` is set to False and the function returns
    without touching anything else.

    Args:
        args: mutable namespace-like object. ``args.dist_url`` is read on the
            distributed paths; ``rank``, ``gpu``, ``distributed`` and
            ``dist_backend`` (and ``world_size`` on the first path) are
            written.
    """
    env = os.environ
    has_rank_and_size = 'RANK' in env and 'WORLD_SIZE' in env

    if not has_rank_and_size and 'SLURM_PROCID' not in env:
        print('Not using distributed mode')
        args.distributed = False
        return

    if has_rank_and_size:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    else:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print(f'| distributed init (rank {args.rank}): {args.dist_url}', flush=True)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    # Wait for every process to finish initialising before going on.
    dist.barrier()
    # From here on only rank 0 prints unless force=True is passed.
    setup_for_distributed(args.rank == 0)
|