diff --git a/dmlcloud/util/distributed.py b/dmlcloud/util/distributed.py index b76d211..4db2e65 100644 --- a/dmlcloud/util/distributed.py +++ b/dmlcloud/util/distributed.py @@ -79,14 +79,17 @@ def print_worker(msg, barrier=True, flush=True): dist.barrier() -def init_process_group_dummy(): +def init_process_group_dummy(**kwargs): """ Initializes the process group with a single process. Uses HashStore under the hood. Useful for applications that only run on a single gpu. """ + backend = kwargs.get('backend', None) + if backend is None: + backend = 'cpu:gloo,cuda:nccl' if dist.is_nccl_available() else 'gloo' store = dist.HashStore() - dist.init_process_group(store=store, rank=0, world_size=1, backend='gloo') + dist.init_process_group(store=store, rank=0, world_size=1, backend=backend, **kwargs) def init_process_group_MPI(ip_idx=0, port=None, **kwargs):