Skip to content

Commit

Permalink
Support dist.all_to_all_single
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed Sep 24, 2024
1 parent e657c87 commit d772796
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 3 deletions.
38 changes: 37 additions & 1 deletion test/pjrt/test_collective_ops_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
"Dynamo not supported on TPU v2/v3")
class TestDistCollectiveOpsTpu(parameterized.TestCase):
"""Test for collective ops from torch.distributed"""
Expand Down Expand Up @@ -246,6 +246,32 @@ def callable(output, input):
assert 'xla::reduce_scatter_tensor' in met.counter_names()
return output.cpu()

@staticmethod
def _all_to_all_single(use_dynamo: bool):
met.clear_all()
dist.init_process_group("xla", init_method='xla://')
device = xm.xla_device()

def callable(output, input):
dist.all_to_all_single(output, input)
return output

# check https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
# for input and output tensor example
tensor_in = torch.tensor(
[xr.local_ordinal()] * tpu.num_expected_global_devices(),
dtype=torch.float,
device=device)
tensor_out = torch.zeros_like(tensor_in)
f = torch.compile(callable, backend='openxla') if use_dynamo else callable
output = f(tensor_out, tensor_in)
torch_xla.sync()
if not use_dynamo:
assert 'xla::AllToAll' in met.counter_names()
else:
assert 'xla::all_to_all_single' in met.counter_names()
return output.cpu()

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_reduce(self, use_dynamo):
results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
Expand Down Expand Up @@ -287,6 +313,16 @@ def test_reduce_scatter(self, use_dynamo):
for index, val in results.items():
torch.testing.assert_close(val, expected[index])

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_to_all_single(self, use_dynamo):
results = pjrt.run_multiprocess(
self._all_to_all_single, use_dynamo=use_dynamo)
expected = torch.arange(
tpu.num_expected_global_devices(), dtype=torch.float)
# Note: all_to_all xla op does not honor the order of the all_to_all.
for _, val in results.items():
self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))


if __name__ == '__main__':
absltest.main()
39 changes: 39 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_methods.h"
Expand Down Expand Up @@ -309,6 +310,44 @@ AllGatherResultCoalesced BuildAllGatherCoalesced(
return {result, token_handler.GetNewToken(result[0])};
}

at::Tensor all_to_all_single(const at::Tensor& input,
std::vector<int64_t> output_split_sizes,
std::vector<int64_t> input_split_sizes,
std::string group_name) {
// this basically is the code copy from
// init_python_bindings.cpp:_xla_all_to_all
TORCH_LAZY_FN_COUNTER("xla::");
if (output_split_sizes.size() != 0 && input_split_sizes.size() != 0) {
for (size_t i = 0; i < input_split_sizes.size(); i++) {
if (input_split_sizes[i] != 1)
throw std::runtime_error(
"torch_xla does not support arbitrary split sizes for all_to_all");
}
}
bool pin_layout = false;
const torch::lazy::Value& token =
GetAllReduceToken(bridge::GetCurrentDevice());
int64_t split_count = runtime::GetComputationClient()->GetAllDevices().size();
std::vector<int64_t> all_groups(split_count);
std::iota(all_groups.begin(), all_groups.end(), 0);
XLATensorPtr result_ptr;
torch::lazy::Value new_token;
std::tie(result_ptr, new_token) =
tensor_methods::all_to_all(bridge::GetXlaTensor(input), token, 0, 0,
split_count, {all_groups}, pin_layout);
at::Tensor result = bridge::AtenFromXlaTensor(std::move(result_ptr));

at::Tensor result_with_grad = torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
SetAllReduceToken(bridge::GetCurrentDevice(),
std::make_shared<torch::lazy::Value>(new_token));
return result_with_grad;
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
m.impl("all_to_all_single", all_to_all_single);
}

CollectivePermuteResult BuildCollectivePermute(
xla::XlaOp input, xla::XlaOp token,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> AllToAll(
const at::Tensor& input, const std::shared_ptr<torch::lazy::Value>& token,
int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::all_to_all(
Expand Down
15 changes: 13 additions & 2 deletions torch_xla/distributed/xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,19 @@ def allreduce_coalesced(self, *args):
def alltoall(self, *args):
raise NotImplementedError

def alltoall_base(self, *args):
raise NotImplementedError
# handle the nondynamo path when call torch.distributed.all_to_all_single
# call from https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996
# Note for pytorch, the split/concat dimension is always 0, while for XLA alltoall,
# we can't specify different split sizes.
def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
opts):
assert (output_split_sizes is None or len(output_split_sizes) == 0) and \
(input_split_sizes is None or len(input_split_sizes) == 0), \
"XLA doesn't support specifying non-empty output_split_sizes and input_split_sizes"
split_count = xr.world_size()
result = xm.all_to_all(input, 0, 0, split_count, pin_layout=False)
output.copy_(result)
return _ret_work(output)

def gather(self, *args):
raise NotImplementedError
Expand Down

0 comments on commit d772796

Please sign in to comment.