diff --git a/torch_patches/dpcpp-v1.5-rc3.patch b/torch_patches/dpcpp-v1.5-rc3.patch
index 9945938b6..30b0fe5ed 100644
--- a/torch_patches/dpcpp-v1.5-rc3.patch
+++ b/torch_patches/dpcpp-v1.5-rc3.patch
@@ -421,6 +421,291 @@ index 9a4c9b3..6d02405 100644
   } else if (tid == DispatchKey::SparseCPUTensorId) {
     return DeviceType::CPU;
   } else if (tid == DispatchKey::SparseCUDATensorId) {
+diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst
+index 180d1d4..d0c6f8e 100644
+--- a/docs/source/distributed.rst
++++ b/docs/source/distributed.rst
+@@ -10,7 +10,7 @@ Distributed communication package - torch.distributed
+ Backends
+ --------
+ 
+-``torch.distributed`` supports three backends, each with
++``torch.distributed`` supports three built-in backends, each with
+ different capabilities. The table below shows which functions are available
+ for use with CPU / CUDA tensors.
+ MPI supports CUDA only if the implementation used to build PyTorch supports it.
+@@ -39,7 +39,8 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it.
+ +------------+-----+-----+-----+-----+-----+-----+
+ | barrier    | ✓   | ✘   | ✓   | ?   | ✘   | ✓   |
+ +------------+-----+-----+-----+-----+-----+-----+
+-
++| all_to_all | ✘   | ✘   | ✓   | ?   | ✘   | ✘   |
+++------------+-----+-----+-----+-----+-----+-----+
+ 
+ Backends that come with PyTorch
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+@@ -319,6 +320,8 @@ Collective functions
+ 
+ .. autofunction:: barrier
+ 
++.. autofunction:: all_to_all
++
+ .. autoclass:: ReduceOp
+ 
+ .. class:: reduce_op
+@@ -397,6 +400,26 @@ of 16
+ 
+ .. _distributed-launch:
+ 
++Third-party backends
++--------------------
++
++Besides the built-in GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends
++through a run-time registration mechanism.
++For a reference on how to develop a third-party backend through a C++ extension,
++please refer to `Tutorials - Custom C++ and CUDA Extensions `_ and `test/cpp_extensions/cpp_c10d_extension.cpp`.
++The capabilities of third-party backends are determined by their own implementations.
++
++The new backend derives from `c10d.ProcessGroup` and registers the backend name and the
++instantiating interface through :func:`torch.distributed.Backend.register_backend` when
++imported.
++
++When manually importing this backend and invoking :func:`torch.distributed.init_process_group`
++with the corresponding backend name, the `torch.distributed` package runs on the new backend.
++
++.. warning::
++    Support for third-party backends is experimental and subject to change.
++ + Launch utility + -------------- + +diff --git a/setup.py b/setup.py +index 7352d3b..977f8fc 100644 +--- a/setup.py ++++ b/setup.py +@@ -811,6 +811,7 @@ if __name__ == '__main__': + 'include/c10/cuda/impl/*.h', + 'include/c10/hip/*.h', + 'include/c10/hip/impl/*.h', ++ 'include/c10d/*.hpp', + 'include/caffe2/**/*.h', + 'include/torch/*.h', + 'include/torch/csrc/*.h', +diff --git a/test/distributed/test_distributed.py b/test/distributed/test_distributed.py +index 37bc4ac..a8a1997 100644 +--- a/test/distributed/test_distributed.py ++++ b/test/distributed/test_distributed.py +@@ -17,7 +17,8 @@ import torch.cuda + import torch.distributed as dist + import torch.nn as nn + import torch.nn.functional as F +-from torch.testing._internal.common_utils import TestCase, run_tests ++from torch.testing._internal.common_utils import TestCase, run_tests, find_free_port ++from torch.distributed.distributed_c10d import _get_default_group + from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR + from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT + from torch.testing._internal.common_distributed import simple_sparse_reduce_tests, skip_if_rocm +@@ -31,6 +32,12 @@ except ImportError: + + skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + ++CPP_EXTENSIONS_WARNING = """ ++Ninja (https://ninja-build.org) must be available to run C++ extensions tests, ++but it could not be found. Install ninja with `pip install ninja` ++or `conda install ninja`. ++""" ++ + BACKEND = os.environ["BACKEND"] + TEMP_DIR = os.environ["TEMP_DIR"] + INIT_METHOD = os.getenv("INIT_METHOD", "env://") +@@ -150,6 +157,21 @@ def skip_if_small_worldsize(func): + return wrapper + + ++def skip_if_no_ninja(func): ++ ++ @wraps(func) ++ def wrapper(*args, **kwargs): ++ try: ++ import torch.utils.cpp_extension ++ torch.utils.cpp_extension.verify_ninja_availability() ++ except RuntimeError: ++ print(CPP_EXTENSIONS_WARNING) ++ return 0 ++ ++ return func(*args, **kwargs) ++ ++ return wrapper ++ + def require_backend(backends): + if BACKEND not in backends: + return unittest.skip("Test requires backend to be one of %s" % backends) +@@ -1511,6 +1533,92 @@ class _DistTestBase(object): + output_tensors_lists, input_tensors, expected_tensors, group_id) + self._barrier() + ++ # AllToAll ++ def _test_all_to_all_single_equal_split_helper(self, group, group_id, rank): ++ if group_id is not None: ++ size = len(group) ++ in_tensor = torch.ones([size, size]) * rank ++ expected_tensor = torch.cat([torch.ones([1, size]) * i for i in group]) ++ out_tensor = torch.ones([size, size]) * -1 ++ dist.all_to_all_single(out_tensor, in_tensor, group=group_id) ++ self.assertEqual(out_tensor, expected_tensor) ++ self._barrier() ++ ++ def _test_all_to_all_single_unequal_split_helper(self, group, group_id, rank): ++ if group_id is not None: ++ size = len(group) ++ in_splits = [i + 1 for i in group] ++ out_splits = [rank + 1 for _ in group] ++ in_tensor = torch.ones([sum(in_splits), size]) * rank ++ out_tensor = torch.ones([(rank + 1) * size, size]) ++ expected_tensor = torch.cat([torch.ones([rank + 1, size]) * i for i in group]) ++ dist.all_to_all_single( ++ out_tensor, in_tensor, out_splits, in_splits, group=group_id) ++ self.assertEqual(out_tensor, expected_tensor) ++ self._barrier() ++ ++ def _test_all_to_all_helper(self, group, group_id, rank): ++ if group_id is not None: ++ size = len(group) ++ in_splits = [i + 1 for i in group] ++ in_tensors = [ ++ torch.ones([in_splits[i], size]) * rank for i, _ in enumerate(group) ++ 
] ++ out_tensors = [torch.ones([(rank + 1), size]) for _ in group] ++ expected_tensors = [torch.ones([rank + 1, size]) * i for i in group] ++ dist.all_to_all(out_tensors, in_tensors, group=group_id) ++ for t1, t2 in zip(out_tensors, expected_tensors): ++ self.assertEqual(t1, t2) ++ self._barrier() ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single") ++ def test_all_to_all_single_equal_split(self): ++ group, group_id, rank = self._init_global_test() ++ self._test_all_to_all_single_equal_split_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single") ++ def test_all_to_all_single_unequal_split(self): ++ group, group_id, rank = self._init_global_test() ++ self._test_all_to_all_single_unequal_split_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all") ++ def test_all_to_all(self): ++ group, group_id, rank = self._init_global_test() ++ self._test_all_to_all_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single") ++ @skip_if_small_worldsize ++ def test_all_to_all_single_equal_split_group(self): ++ group, group_id, rank = self._init_group_test() ++ self._test_all_to_all_single_equal_split_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single") ++ @skip_if_small_worldsize ++ def test_all_to_all_single_unequal_split_group(self): ++ group, group_id, rank = self._init_group_test() ++ self._test_all_to_all_single_unequal_split_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all") ++ @skip_if_small_worldsize ++ def test_all_to_all_group(self): ++ group, group_id, rank = self._init_group_test() ++ self._test_all_to_all_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single") ++ def test_all_to_all_single_equal_split_full_group(self): ++ group, group_id, rank = self._init_full_group_test() ++ self._test_all_to_all_single_equal_split_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single") ++ def test_all_to_all_single_unequal_split_full_group(self): ++ group, group_id, rank = self._init_full_group_test() ++ self._test_all_to_all_single_unequal_split_helper(group, group_id, rank) ++ ++ @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all") ++ def test_all_to_all_full_group(self): ++ group, group_id, rank = self._init_full_group_test() ++ self._test_all_to_all_helper(group, group_id, rank) ++ + # BARRIER + def _test_barrier_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None): +@@ -2181,6 +2289,45 @@ elif BACKEND == "mpi": + class TestMPI(TestCase, _DistTestBase): + pass + ++elif BACKEND == "test": ++ class TestBackendDynamicLoad(TestCase): ++ def setUp(self): ++ super(TestBackendDynamicLoad, self).setUp() ++ ++ def _load_test_backend(self): ++ temp_dir = tempfile.mkdtemp() ++ src = "{}/../cpp_extensions/cpp_c10d_extension.cpp".format(os.path.abspath(os.path.dirname(__file__))) ++ extension = torch.utils.cpp_extension.load( ++ name="torch_test", ++ sources=[src], ++ build_directory=temp_dir ++ ) ++ ++ @skip_if_no_ninja ++ def test_backend_apis(self): ++ self._load_test_backend() ++ ++ os.environ['WORLD_SIZE'] = '1' ++ os.environ['MASTER_ADDR'] = '127.0.0.1' ++ os.environ['MASTER_PORT'] = str(find_free_port()) ++ os.environ['RANK'] = '0' ++ ++ dist.init_process_group(backend='test', 
init_method='env://', world_size=1, rank=0) ++ self.assertEqual(dist.get_rank(), 0) ++ self.assertEqual(dist.get_world_size(), 1) ++ ++ process_group = _get_default_group() ++ work = process_group.allreduce([torch.rand(1), torch.rand(1)]) ++ self.assertTrue(work.wait()) ++ self.assertTrue(work.is_completed()) ++ self.assertTrue(work.is_success()) ++ ++ work = process_group.broadcast([torch.rand(1)]) ++ self.assertTrue(work.wait()) ++ self.assertTrue(work.is_completed()) ++ self.assertTrue(work.is_success()) ++ ++ dist.destroy_process_group() + + if __name__ == "__main__": + assert ( +diff --git a/test/run_test.py b/test/run_test.py +index f9ffeae..7e82b87 100755 +--- a/test/run_test.py ++++ b/test/run_test.py +@@ -148,6 +148,9 @@ DISTRIBUTED_TESTS_CONFIG = {} + + + if dist.is_available(): ++ DISTRIBUTED_TESTS_CONFIG['test'] = { ++ 'WORLD_SIZE': '1' ++ } + if not TEST_WITH_ROCM and dist.is_mpi_available(): + DISTRIBUTED_TESTS_CONFIG['mpi'] = { + 'WORLD_SIZE': '3', +diff --git a/test/test_determination.py b/test/test_determination.py +index 319abb0..b9ff7c4 100644 +--- a/test/test_determination.py ++++ b/test/test_determination.py +@@ -92,6 +92,7 @@ class DeterminationTest(unittest.TestCase): + self.assertEqual( + self.determined_tests(["torch/utils/cpp_extension.py"]), + [ ++ "distributed/test_distributed", + "test_cpp_extensions_aot_ninja", + "test_cpp_extensions_aot_no_ninja", + "test_determination", diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 2a9dc9d..9410392 100644 --- a/tools/autograd/templates/python_variable_methods.cpp @@ -454,6 +739,77 @@ index 2a9dc9d..9410392 100644 {"data_ptr", (PyCFunction)THPVariable_data_ptr, METH_NOARGS, NULL}, {"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL}, {"has_names", (PyCFunction)THPVariable_has_names, METH_NOARGS, NULL}, +diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp +index dce5201..3b2e4fe 100644 +--- a/torch/csrc/distributed/c10d/init.cpp ++++ b/torch/csrc/distributed/c10d/init.cpp +@@ -204,6 +204,10 @@ They are used in specifying strategies for reduction collectives, e.g., + .def(py::init<>()) + .def_readwrite("timeout", &::c10d::BarrierOptions::timeout); + ++ py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") ++ .def(py::init<>()) ++ .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); ++ + auto store = + py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>( + module, "Store") +@@ -470,6 +474,55 @@ They are used in specifying strategies for reduction collectives, e.g., + py::call_guard()) + + .def( ++ "alltoall_base", ++ &::c10d::ProcessGroup::alltoall_base, ++ py::arg("output_tensor"), ++ py::arg("input_tensor"), ++ py::arg("output_split_sizes"), ++ py::arg("input_split_sizes"), ++ py::arg("opts") = ::c10d::AllToAllOptions(), ++ py::call_guard()) ++ ++ .def( ++ "alltoall_base", ++ [](::c10d::ProcessGroup& pg, ++ at::Tensor& output, ++ at::Tensor& input, ++ std::vector outputSplitSizes, ++ std::vector inputSplitSizes) { ++ return pg.alltoall_base( ++ output, ++ input, ++ outputSplitSizes, ++ inputSplitSizes, ++ ::c10d::AllToAllOptions()); ++ }, ++ py::arg("output"), ++ py::arg("input"), ++ py::arg("output_split_sizes"), ++ py::arg("input_split_sizes"), ++ py::call_guard()) ++ ++ .def( ++ "alltoall", ++ &::c10d::ProcessGroup::alltoall, ++ py::arg("output_tensor"), ++ py::arg("input_tensor"), ++ py::arg("opts") = ::c10d::AllToAllOptions(), ++ py::call_guard()) ++ ++ .def( ++ 
"alltoall", ++ [](::c10d::ProcessGroup& pg, ++ std::vector& output, ++ std::vector& input) { ++ return pg.alltoall(output, input, ::c10d::AllToAllOptions()); ++ }, ++ py::arg("output"), ++ py::arg("input"), ++ py::call_guard()) ++ ++ .def( + "send", + &::c10d::ProcessGroup::send, + py::call_guard()) diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp index a6a9fca..d42a05f 100644 --- a/torch/csrc/jit/passes/quantization.cpp @@ -751,3 +1107,553 @@ index e6b851a..3fa1a88 100644 default: AT_ERROR("Unimplemented backend ", backend); } } +diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py +index 4ca9596..4027312 100644 +--- a/torch/distributed/distributed_c10d.py ++++ b/torch/distributed/distributed_c10d.py +@@ -11,6 +11,7 @@ from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401 + from . import ( + AllreduceOptions, + AllreduceCoalescedOptions, ++ AllToAllOptions, + BroadcastOptions, + GatherOptions, + ReduceOptions, +@@ -44,7 +45,8 @@ except ImportError: + + class Backend(object): + """ +- An enum-like class of available backends: GLOO, NCCL, and MPI. ++ An enum-like class of available backends: GLOO, NCCL, MPI, and other registered ++ backends. + + The values of this class are lowercase strings, e.g., ``"gloo"``. They can + be accessed as attributes, e.g., ``Backend.NCCL``. +@@ -75,8 +77,29 @@ class Backend(object): + "on CPU tensors.") + elif value == Backend.UNDEFINED: + raise ValueError("Invalid backend: '{}'".format(name)) ++ elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.MPI: ++ value = name + return value + ++ @classmethod ++ def register_backend(cls, name, func): ++ """ ++ Registers a new backend. ++ ++ This class method is used by 3rd party cpp extension to register new backend. ++ ++ Arguments: ++ name (str): Backend name matching with the one in `init_process_group()`. ++ func (function): Function handler that instantiates the backend. ++ The function should be implemented in the backend cpp extension ++ and takes four arguments, including prefix_store, rank, ++ world_size, and timeout. ++ ++ .. note:: This support of 3rd party backend is experimental and subject to change. ++ ++ """ ++ setattr(Backend, name.upper(), func) ++ + # `_backend`, `dist_backend`, and `reduce_op` are here to maintain backward + # compatibility with pre-c10d distributed package. + # TODO: remove them when users are ready to take a hard dependency on PyTorch 1. +@@ -483,7 +506,13 @@ def _new_process_group_helper(world_size, + _pg_map[pg] = (Backend.NCCL, store) + _pg_names[pg] = group_name + else: +- raise RuntimeError("Unsupported distributed backend by group") ++ pg = getattr(Backend, backend.upper())( ++ prefix_store, ++ rank, ++ world_size, ++ timeout) ++ _pg_map[pg] = (backend, store) ++ _pg_names[pg] = group_name + + return pg + +@@ -1461,6 +1490,193 @@ def reduce_scatter(output, + work.wait() + + ++def all_to_all_single(output, ++ input, ++ output_split_sizes=None, ++ input_split_sizes=None, ++ group=group.WORLD, ++ async_op=False): ++ """ ++ Each process splits input tensor and then scatters the split list ++ to all processes in a group. Then concatenate the received tensors from all ++ the processes in the group and return single output tensor. ++ ++ Arguments: ++ output (Tensor): Gathered cancatenated output tensor. ++ input (Tensor): Input tensor to scatter. 
++        output_split_sizes (list[Int], optional): Output split sizes for dim 0.
++            If None or empty, dim 0 of the ``output`` tensor must be evenly
++            divisible by ``world_size``.
++        input_split_sizes (list[Int], optional): Input split sizes for dim 0.
++            If None or empty, dim 0 of the ``input`` tensor must be evenly
++            divisible by ``world_size``.
++        group (ProcessGroup, optional): The process group to work on.
++        async_op (bool, optional): Whether this op should be an async op.
++
++    Returns:
++        Async work handle, if async_op is set to True.
++        None, if not async_op or if not part of the group.
++
++    .. warning::
++        `all_to_all_single` is experimental and subject to change.
++
++    Examples:
++        >>> input = torch.arange(4) + rank * 4
++        >>> input
++        tensor([0, 1, 2, 3])     # Rank 0
++        tensor([4, 5, 6, 7])     # Rank 1
++        tensor([8, 9, 10, 11])   # Rank 2
++        tensor([12, 13, 14, 15]) # Rank 3
++        >>> output = torch.empty([4], dtype=torch.int64)
++        >>> dist.all_to_all_single(output, input)
++        >>> output
++        tensor([0, 4, 8, 12])    # Rank 0
++        tensor([1, 5, 9, 13])    # Rank 1
++        tensor([2, 6, 10, 14])   # Rank 2
++        tensor([3, 7, 11, 15])   # Rank 3
++
++        >>> # Essentially, it is similar to the following operation:
++        >>> scatter_list = list(input.chunk(world_size))
++        >>> gather_list = list(output.chunk(world_size))
++        >>> for i in range(world_size):
++        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
++
++        >>> # Another example with uneven split
++        >>> input
++        tensor([0, 1, 2, 3, 4, 5])                   # Rank 0
++        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1
++        tensor([20, 21, 22, 23, 24])                 # Rank 2
++        tensor([30, 31, 32, 33, 34, 35, 36])         # Rank 3
++        >>> input_splits
++        [2, 2, 1, 1]                                 # Rank 0
++        [3, 2, 2, 2]                                 # Rank 1
++        [2, 1, 1, 1]                                 # Rank 2
++        [2, 2, 2, 1]                                 # Rank 3
++        >>> output_splits
++        [2, 3, 2, 2]                                 # Rank 0
++        [2, 2, 1, 2]                                 # Rank 1
++        [1, 2, 1, 2]                                 # Rank 2
++        [1, 2, 1, 1]                                 # Rank 3
++        >>> output = ...
++        >>> dist.all_to_all_single(output, input, output_splits, input_splits)
++        >>> output
++        tensor([ 0,  1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0
++        tensor([ 2,  3, 13, 14, 22, 32, 33])         # Rank 1
++        tensor([ 4, 15, 16, 23, 34, 35])             # Rank 2
++        tensor([ 5, 17, 18, 24, 36])                 # Rank 3
++    """
++    if _rank_not_in_group(group):
++        return
++
++    opts = AllToAllOptions()
++    _check_single_tensor(output, "output")
++    _check_single_tensor(input, "input")
++    output_split_sizes = [] if output_split_sizes is None else output_split_sizes
++    input_split_sizes = [] if input_split_sizes is None else input_split_sizes
++
++    if group == GroupMember.WORLD:
++        _check_default_pg()
++        work = _default_pg.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts)
++    else:
++        work = group.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts)
++
++    if async_op:
++        return work
++    else:
++        work.wait()
++
++def all_to_all(output_tensor_list,
++               input_tensor_list,
++               group=group.WORLD,
++               async_op=False):
++    """
++    Each process scatters a list of input tensors to all processes in a group
++    and returns a gathered list of tensors in the output list.
++
++    Arguments:
++        output_tensor_list (list[Tensor]): List of tensors to be gathered one
++            per rank.
++        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
++        group (ProcessGroup, optional): The process group to work on.
++        async_op (bool, optional): Whether this op should be an async op.
++
++    Returns:
++        Async work handle, if async_op is set to True.
++        None, if not async_op or if not part of the group.
++
++    .. warning::
++        `all_to_all` is experimental and subject to change.
++
++    Examples:
++        >>> input = torch.arange(4) + rank * 4
++        >>> input = list(input.chunk(4))
++        >>> input
++        [tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
++        [tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
++        [tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
++        [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
++        >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
++        >>> dist.all_to_all(output, input)
++        >>> output
++        [tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
++        [tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
++        [tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
++        [tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3
++
++        >>> # Essentially, it is similar to the following operation:
++        >>> scatter_list = input
++        >>> gather_list = output
++        >>> for i in range(world_size):
++        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
++
++        >>> input
++        tensor([0, 1, 2, 3, 4, 5])                   # Rank 0
++        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1
++        tensor([20, 21, 22, 23, 24])                 # Rank 2
++        tensor([30, 31, 32, 33, 34, 35, 36])         # Rank 3
++        >>> input_splits
++        [2, 2, 1, 1]                                 # Rank 0
++        [3, 2, 2, 2]                                 # Rank 1
++        [2, 1, 1, 1]                                 # Rank 2
++        [2, 2, 2, 1]                                 # Rank 3
++        >>> output_splits
++        [2, 3, 2, 2]                                 # Rank 0
++        [2, 2, 1, 2]                                 # Rank 1
++        [1, 2, 1, 2]                                 # Rank 2
++        [1, 2, 1, 1]                                 # Rank 3
++        >>> input = list(input.split(input_splits))
++        >>> input
++        [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
++        [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
++        [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
++        [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
++        >>> output = ...
++ >>> dist.all_to_all(output, input) ++ >>> output ++ [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 ++ [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 ++ [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 ++ [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 ++ """ ++ if _rank_not_in_group(group): ++ return ++ ++ opts = AllToAllOptions() ++ _check_tensor_list(output_tensor_list, "output_tensor_list") ++ _check_tensor_list(input_tensor_list, "input_tensor_list") ++ ++ if group == GroupMember.WORLD: ++ _check_default_pg() ++ work = _default_pg.alltoall(output_tensor_list, input_tensor_list, opts) ++ else: ++ work = group.alltoall(output_tensor_list, input_tensor_list, opts) ++ ++ if async_op: ++ return work ++ else: ++ work.wait() ++ ++ + def barrier(group=group.WORLD, + async_op=False): + """ +diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp +index ac29f13..c98b1a9 100644 +--- a/torch/lib/c10d/ProcessGroup.hpp ++++ b/torch/lib/c10d/ProcessGroup.hpp +@@ -162,6 +162,22 @@ class ProcessGroup { + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; + ++ virtual std::shared_ptr alltoall_base( ++ at::Tensor& outputTensor, ++ at::Tensor& inputTensor, ++ std::vector& outputSplitSizes, ++ std::vector& inputSplitSizes, ++ const AllToAllOptions& opts = AllToAllOptions()) { ++ throw std::runtime_error("ProcessGroup does not support alltoall"); ++ } ++ ++ virtual std::shared_ptr alltoall( ++ std::vector& outputTensors, ++ std::vector& inputTensors, ++ const AllToAllOptions& opts = AllToAllOptions()) { ++ throw std::runtime_error("ProcessGroup does not support alltoall"); ++ } ++ + virtual std::shared_ptr send( + std::vector& tensors, + int dstRank, +diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp +index d09620a..a822783 100644 +--- a/torch/lib/c10d/ProcessGroupMPI.cpp ++++ b/torch/lib/c10d/ProcessGroupMPI.cpp +@@ -1,5 +1,6 @@ + #include + ++#include + #include + + #include +@@ -91,6 +92,72 @@ void checkSameSizeAndType( + } + } + ++void checkSplitSizes( ++ const std::vector& split_sizes, ++ const at::Tensor& tensor, ++ int group_size) { ++ if (split_sizes.size() == 0) { ++ TORCH_CHECK( ++ tensor.size(0) % group_size == 0, ++ "Tensor's dim 0 does not divide equally across group size"); ++ } else { ++ TORCH_CHECK( ++ split_sizes.size() == group_size, ++ "Number of tensor splits not equal to group size"); ++ int sum = std::accumulate(split_sizes.begin(), split_sizes.end(), 0); ++ TORCH_CHECK( ++ sum == tensor.size(0), "Split sizes doesn't match total dim 0 size"); ++ } ++} ++ ++int64_t computeLengthsAndOffsets( ++ const std::vector& split_sizes, ++ const at::Tensor& tensor, ++ std::vector* lengths, ++ std::vector* offsets) { ++ int64_t group_size = lengths->size(); ++ bool equal_splits = false; ++ int64_t dim0_size = tensor.size(0); ++ int64_t row_size = (dim0_size ? tensor.numel() / dim0_size : 1); ++ int64_t split_size = 0; ++ int64_t offset = 0; ++ ++ if (split_sizes.size() == 0) { ++ equal_splits = true; ++ split_size = tensor.size(0) / group_size; ++ } ++ for (int i = 0; i < group_size; i++) { ++ int64_t length = row_size * (equal_splits ? 
split_size : split_sizes[i]); ++ TORCH_INTERNAL_ASSERT( ++ length <= std::numeric_limits::max() && ++ offset <= std::numeric_limits::max(), ++ "Length or offset larger than INT_MAX not supported"); ++ (*lengths)[i] = length; ++ (*offsets)[i] = offset; ++ offset += length; ++ } ++ return offset; ++} ++ ++int64_t computeLengthsAndOffsets( ++ const std::vector& tensors, ++ std::vector* lengths, ++ std::vector* offsets) { ++ int64_t group_size = lengths->size(); ++ int64_t offset = 0; ++ for (int i = 0; i < group_size; i++) { ++ int64_t length = tensors[i].numel(); ++ TORCH_INTERNAL_ASSERT( ++ length <= std::numeric_limits::max() && ++ offset <= std::numeric_limits::max(), ++ "Length or offset larger than INT_MAX not supported"); ++ (*lengths)[i] = length; ++ (*offsets)[i] = offset; ++ offset += length; ++ } ++ return offset; ++} ++ + } // namespace + + ProcessGroupMPI::AsyncWork::AsyncWork(at::Tensor tensor, MPI_Request request) +@@ -588,6 +655,139 @@ std::shared_ptr ProcessGroupMPI::reduce_scatter( + throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); + } + ++std::shared_ptr ProcessGroupMPI::alltoall_base( ++ at::Tensor& outputTensor, ++ at::Tensor& inputTensor, ++ std::vector& outputSplitSizes, ++ std::vector& inputSplitSizes, ++ const AllToAllOptions& opts) { ++ checkSingleTensorHelper(inputTensor); ++ checkSingleTensorHelper(outputTensor); ++ ++ if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { ++ // We can use alltoall ++ TORCH_CHECK( ++ outputTensor.numel() == inputTensor.numel() && ++ outputTensor.type() == inputTensor.type(), ++ "Tensors are not equal in size or data type"); ++ TORCH_CHECK( ++ outputTensor.size(0) % size_ == 0, ++ "Tensor's dim 0 does not divide equally across group size"); ++ ++ std::function&)> runFunc = ++ [opts, this](std::unique_ptr& entry) { ++ auto srcdata = (entry->src)[0]; ++ auto dstdata = (entry->dst)[0]; ++ c10::DeviceGuard guard(srcdata.device()); ++ std::unique_lock globalLock(pgGlobalMutex_); ++ MPI_CHECK(MPI_Alltoall( ++ srcdata.data_ptr(), ++ srcdata.numel() / size_, ++ mpiDatatype.at(srcdata.scalar_type()), ++ dstdata.data_ptr(), ++ dstdata.numel() / size_, ++ mpiDatatype.at(dstdata.scalar_type()), ++ pgComm_)); ++ }; ++ std::vector inputTensors = {inputTensor}; ++ std::vector outputTensors = {outputTensor}; ++ auto entry = std::unique_ptr( ++ new WorkEntry(&inputTensors, &outputTensors, std::move(runFunc))); ++ return enqueue(std::move(entry)); ++ } else { ++ // Need alltoallv ++ checkSplitSizes(inputSplitSizes, inputTensor, size_); ++ checkSplitSizes(outputSplitSizes, outputTensor, size_); ++ std::function&)> runFunc = ++ [opts, this, inputSplitSizes, outputSplitSizes]( ++ std::unique_ptr& entry) { ++ auto srcdata = (entry->src)[0]; ++ auto dstdata = (entry->dst)[0]; ++ std::vector send_lengths(size_); ++ std::vector recv_lengths(size_); ++ std::vector send_offsets(size_); ++ std::vector recv_offsets(size_); ++ computeLengthsAndOffsets( ++ inputSplitSizes, srcdata, &send_lengths, &send_offsets); ++ computeLengthsAndOffsets( ++ outputSplitSizes, dstdata, &recv_lengths, &recv_offsets); ++ c10::DeviceGuard guard(srcdata.device()); ++ std::unique_lock globalLock(pgGlobalMutex_); ++ MPI_CHECK(MPI_Alltoallv( ++ srcdata.data_ptr(), ++ send_lengths.data(), ++ send_offsets.data(), ++ mpiDatatype.at(srcdata.scalar_type()), ++ dstdata.data_ptr(), ++ recv_lengths.data(), ++ recv_offsets.data(), ++ mpiDatatype.at(dstdata.scalar_type()), ++ pgComm_)); ++ }; ++ std::vector inputTensors = {inputTensor}; ++ std::vector 
outputTensors = {outputTensor}; ++ auto entry = std::unique_ptr( ++ new WorkEntry(&inputTensors, &outputTensors, std::move(runFunc))); ++ return enqueue(std::move(entry)); ++ } ++} ++std::shared_ptr ProcessGroupMPI::alltoall( ++ std::vector& outputTensors, ++ std::vector& inputTensors, ++ const AllToAllOptions& opts) { ++ TORCH_CHECK( ++ inputTensors.size() == size_, ++ "Number of input tensors are not equal to group size"); ++ TORCH_CHECK( ++ outputTensors.size() == size_, ++ "Number of output tensors are not equal to group size"); ++ std::function&)> runFunc = ++ [opts, this](std::unique_ptr& entry) { ++ std::vector send_lengths(size_); ++ std::vector recv_lengths(size_); ++ std::vector send_offsets(size_); ++ std::vector recv_offsets(size_); ++ auto srcdata = entry->src; ++ auto dstdata = entry->dst; ++ int64_t src_len = ++ computeLengthsAndOffsets(srcdata, &send_lengths, &send_offsets); ++ int64_t dst_len = ++ computeLengthsAndOffsets(dstdata, &recv_lengths, &recv_offsets); ++ std::vector send_lengthsL( ++ send_lengths.begin(), send_lengths.end()); ++ std::vector recv_lengthsL( ++ recv_lengths.begin(), recv_lengths.end()); ++ at::Tensor srcFlatData = at::empty({src_len}, srcdata[0].options()); ++ at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options()); ++ auto srcFlatDataSplits = ++ srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0); ++ for (int i = 0; i < size_; i++) { ++ srcFlatDataSplits[i].copy_(srcdata[i].view({-1})); ++ } ++ c10::DeviceGuard guard1(srcdata[0].device()); ++ std::unique_lock globalLock(pgGlobalMutex_); ++ MPI_CHECK(MPI_Alltoallv( ++ srcFlatData.data_ptr(), ++ send_lengths.data(), ++ send_offsets.data(), ++ mpiDatatype.at(srcdata[0].scalar_type()), ++ dstFlatData.data_ptr(), ++ recv_lengths.data(), ++ recv_offsets.data(), ++ mpiDatatype.at(dstdata[0].scalar_type()), ++ pgComm_)); ++ ++ auto dstFlatDataSplits = ++ dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0); ++ for (int i = 0; i < size_; i++) { ++ dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]); ++ } ++ }; ++ auto entry = std::unique_ptr( ++ new WorkEntry(&inputTensors, &outputTensors, std::move(runFunc))); ++ return enqueue(std::move(entry)); ++} ++ + std::shared_ptr ProcessGroupMPI::send( + std::vector& tensors, + int dstRank, +diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp +index eb42c11..79648f3 100644 +--- a/torch/lib/c10d/ProcessGroupMPI.hpp ++++ b/torch/lib/c10d/ProcessGroupMPI.hpp +@@ -155,6 +155,18 @@ class ProcessGroupMPI : public ProcessGroup { + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + ++ std::shared_ptr alltoall_base( ++ at::Tensor& outputTensor, ++ at::Tensor& inputTensor, ++ std::vector& outputSplitSizes, ++ std::vector& inputSplitSizes, ++ const AllToAllOptions& opts = AllToAllOptions()) override; ++ ++ std::shared_ptr alltoall( ++ std::vector& outputTensors, ++ std::vector& inputTensors, ++ const AllToAllOptions& opts = AllToAllOptions()) override; ++ + std::shared_ptr send( + std::vector& tensors, + int dstRank, +diff --git a/torch/lib/c10d/Types.hpp b/torch/lib/c10d/Types.hpp +index 335f4c5..03b2e59 100644 +--- a/torch/lib/c10d/Types.hpp ++++ b/torch/lib/c10d/Types.hpp +@@ -57,6 +57,10 @@ struct ReduceScatterOptions { + std::chrono::milliseconds timeout = kUnsetTimeout; + }; + ++struct AllToAllOptions { ++ std::chrono::milliseconds timeout = kUnsetTimeout; ++}; ++ + struct BarrierOptions { + std::chrono::milliseconds timeout = kUnsetTimeout; + };
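
The docstrings added above already give per-rank doctest snippets; as a complement, here is a minimal, self-contained usage sketch (illustrative only, not part of the patch) that exercises the two new collectives against the MPI backend. It assumes a PyTorch build that carries this patch with MPI enabled and a launch through an MPI launcher, e.g. `mpirun -n 4 python alltoall_demo.py`; the script name is hypothetical.

# alltoall_demo.py -- hedged usage sketch for the collectives added by this patch.
import torch
import torch.distributed as dist

def main():
    # With the MPI backend, rank and world size are taken from the MPI launcher.
    dist.init_process_group(backend="mpi")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Equal-split case: element i of `inp` is sent to rank i, and element j of
    # `out` is received from rank j (dim 0 must be divisible by world_size).
    inp = torch.arange(world_size, dtype=torch.float32) + rank * world_size
    out = torch.empty(world_size, dtype=torch.float32)
    dist.all_to_all_single(out, inp)

    # Tensor-list variant: one tensor sent to each peer, one received from each.
    in_list = [torch.tensor([float(rank)]) for _ in range(world_size)]
    out_list = [torch.empty(1) for _ in range(world_size)]
    dist.all_to_all(out_list, in_list)

    if rank == 0:
        print("all_to_all_single:", out)
        print("all_to_all:", out_list)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()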