From f9f3a712cd9f1c0b49754291e2cdef77a55c57b6 Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Tue, 28 Nov 2023 18:55:07 +0000
Subject: [PATCH] Fix reduce-scatter-coalesce to be compatible with openxla
 reduce-scatter tuple change without token

---
 test/test_torch_distributed_xla_backend.py | 11 ++++---
 torch_xla/core/xla_model.py                | 16 +++++-----
 torch_xla/csrc/cross_replica_reduces.cpp   | 35 ++++++++++------------
 torch_xla/csrc/cross_replica_reduces.h     | 11 ++++---
 torch_xla/csrc/ops/all_gather.cpp          |  1 -
 torch_xla/csrc/ops/reduce_scatter.cpp      | 16 ++++++----
 6 files changed, 48 insertions(+), 42 deletions(-)

diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 2bc082074b8..47d0ae72139 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -14,12 +14,14 @@
 from datetime import timedelta
 
+
 def get_process_group_xla(rank, size):
   pg_xla_creator = dist.Backend._plugins['XLA'].creator_fn
   pg_xla = pg_xla_creator(
       prefix_store=None, rank=rank, size=size, timeout=timedelta(minutes=1))
   return pg_xla
 
+
 def hlo_matches(hlo, expected_pattern, match_times=1):
   matches = re.findall(expected_pattern, hlo)
   assert len(list(matches)) == match_times, hlo
 
@@ -104,10 +106,10 @@ def test_allgather_coalesced(self):
     output_tensors2 = [torch.zeros_like(tensor2)] * 8
     # because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs'
     # shapes will be same as the inputs' shapes.
+    # Ex: %all-gather.26 = (s64[2]{0}, s64[5]{0}) all-gather(s64[2]{0} %get-tuple-element.24, s64[5]{0} %get-tuple-element.25), replica_groups={}, dimensions={0}
     all_gather_pattern = (
-        r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
-        r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+, '
-        r's64\[] %.+\.\d+\)')
+        r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}\) '
+        r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+\)')
     pg_xla.allgather_coalesced([output_tensors, output_tensors2],
                                [tensor, tensor2])
     hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
@@ -134,7 +136,8 @@ def test_reduce_scatter(self):
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
     hlo_matches(hlo, reduce_scatter_pattern)
 
-  @skipIf(xr.device_type() == 'CPU', "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
+  @skipIf(xr.device_type() == 'CPU',
+          "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
   def test_reduce_scatter_coalesced(self):
     device = xm.xla_device()
     tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index e3dcf6d24fe..f66c99d1c43 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -582,8 +582,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
       torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
       return output
 
-    result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count,
-                                             groups or [], pin_layout)
+    result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
+                                             [], pin_layout)
     return result
 
   # Now the input should be a list of Tensors.
@@ -594,8 +594,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
                                                      shard_count, groups or [],
                                                      pin_layout)
-  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
-  return result[0]
+  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+  return result[:-1]
 
 
 def all_to_all(value,
@@ -773,14 +773,14 @@ def reduce_scatter(reduce_type,
       new_token = torch_xla._XLAC._xla_reduce_scatter_out(
           reduce_type, output, input, token, scale, scatter_dim, shard_count,
          groups or [], pin_layout)
-      devctx.all_reduce_token = new_token
+      torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
       return output
 
    result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
                                                 scale, scatter_dim,
                                                 shard_count, groups or [],
                                                 pin_layout)
-    devctx.all_reduce_token = result[1]
+    torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
    return result[0]
 
  # Now the input should be a list of Tensors.
@@ -801,7 +801,9 @@ def reduce_scatter(reduce_type,
   result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
       reduce_type, output or [], input, token, scale, scatter_dim, shard_count,
       groups or [], pin_layout)
-  devctx.all_reduce_token = result[-1]
+  #torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+  #return result[0]
+  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
   return result[:-1]
 
 
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 48ce47bb1bc..f70c8dbccf6 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -210,10 +210,11 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
-AllGatherResult BuildAllGather(
-    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
-    int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
-    bool pin_layout) {
+AllGatherResult BuildAllGather(absl::Span<const xla::XlaOp> inputs,
+                               xla::XlaOp token, int64_t dim,
+                               int64_t shard_count,
+                               const std::vector<std::vector<int64_t>>& groups,
+                               bool pin_layout) {
   std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
   TokenHandler token_handler(token);
   // TODO: We use pseudo-tokens ATM, which are real values. This need to be
@@ -234,13 +235,12 @@ AllGatherResult BuildAllGather(
           xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
                          dim, shard_count, cc_groups);
     }
-    if (type_ctx.second.indices.size() > 1) {
+    if (type_ctx.second.indices.size() > 1) {
      for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
        size_t op_idx = type_ctx.second.indices[i];
        result[op_idx] = xla::GetTupleElement(all_gather_result, i);
      }
-    }
-    else {
+    } else {
      result[0] = all_gather_result;
    }
  }
@@ -285,22 +285,18 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
   return {result, new_token};
 }
 
-std::vector<xla::XlaOp> BuildReduceScatter(
+ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
   std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
+  TokenHandler token_handler(token);
   // TODO: We use pseudo-tokens ATM, which are real values. This need to be
   // switched to use the real XLA Token once support has been added to XLA
   // ReduceScatter().
-  xla::XlaOp chained_token = token;
   ReduceContext cc_ctx = GetReduceContext(inputs);
   std::vector<xla::XlaOp> result(inputs.size());
   for (auto& type_ctx : cc_ctx.contexts) {
-    xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
-    type_ctx.second.ops.push_back(token_op);
-    type_ctx.second.operand_shapes.push_back(
-        ShapeHelper::ShapeOfXlaOp(token_op));
     xla::XlaOp reduce_result;
     if (pin_layout) {
       reduce_result = xla::ReduceScatter(
@@ -317,7 +313,12 @@ std::vector<xla::XlaOp> BuildReduceScatter(
     }
     for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
       size_t op_idx = type_ctx.second.indices[i];
-      xla::XlaOp gte = xla::GetTupleElement(reduce_result, i);
+      xla::XlaOp gte;
+      if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) {
+        gte = xla::GetTupleElement(reduce_result, i);
+      } else {
+        gte = reduce_result;
+      }
       if (scale != 1.0) {
         xla::XlaOp scaling_value = XlaHelpers::ScalarValue(
             scale, type_ctx.second.operand_shapes[i].element_type(),
@@ -326,12 +327,8 @@ std::vector<xla::XlaOp> BuildReduceScatter(
       }
       result[op_idx] = gte;
     }
-    chained_token =
-        xla::GetTupleElement(reduce_result, type_ctx.second.indices.size());
   }
-  result.push_back(
-      MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token)));
-  return result;
+  return {result, token_handler.GetNewToken(result[0])};
 }
 
 // moved from torch_xla/csrc/ops/all_reduce.cpp
diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h
index b6e8b59f69d..d0692080b35 100644
--- a/torch_xla/csrc/cross_replica_reduces.h
+++ b/torch_xla/csrc/cross_replica_reduces.h
@@ -61,10 +61,10 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                              const std::vector<std::vector<int64_t>>& groups,
                              bool pin_layout);
 
-AllGatherResult BuildAllGather(
-    absl::Span<const xla::XlaOp>, xla::XlaOp token, int64_t dim,
-    int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
-    bool pin_layout);
+AllGatherResult BuildAllGather(absl::Span<const xla::XlaOp>, xla::XlaOp token,
+                               int64_t dim, int64_t shard_count,
+                               const std::vector<std::vector<int64_t>>& groups,
+                               bool pin_layout);
 
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
@@ -76,8 +76,7 @@ SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
 RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
                               int64_t channel_id);
 
-//ReduceScatterResult BuildReduceScatter(
-std::vector<xla::XlaOp> BuildReduceScatter(
+ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
diff --git a/torch_xla/csrc/ops/all_gather.cpp b/torch_xla/csrc/ops/all_gather.cpp
index d5368d8b7bb..12aeab002ac 100644
--- a/torch_xla/csrc/ops/all_gather.cpp
+++ b/torch_xla/csrc/ops/all_gather.cpp
@@ -32,7 +32,6 @@ xla::Shape NodeOutputShape(c10::ArrayRef<torch::lazy::Value> inputs,
   }
   input_shapes.emplace_back(GetXlaShape(token));
   return InferOutputShape(input_shapes, shape_fn);
-
 }
 
 }  // namespace
diff --git a/torch_xla/csrc/ops/reduce_scatter.cpp b/torch_xla/csrc/ops/reduce_scatter.cpp
index 941888939f6..a1f2c912457 100644
--- a/torch_xla/csrc/ops/reduce_scatter.cpp
+++ b/torch_xla/csrc/ops/reduce_scatter.cpp
@@ -18,10 +18,15 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type,
                            const std::vector<std::vector<int64_t>>& groups,
                            bool pin_layout) {
   auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    std::vector<xla::XlaOp> result = BuildReduceScatter(
+    ReduceScatterResult result = BuildReduceScatter(
         reduce_type, operands.subspan(0, operands.size() - 1), operands.back(),
         scale, scatter_dim, shard_count, groups, pin_layout);
-    return xla::Tuple(operands[0].builder(), result);
+    std::vector<xla::XlaOp> outputs;
+    for (size_t i = 0; i < result.result.size(); ++i) {
+      outputs.emplace_back(result.result[i]);
+    }
+    outputs.emplace_back(result.token);
+    return xla::Tuple(operands[0].builder(), outputs);
   };
   std::vector<xla::Shape> input_shapes;
   for (const auto& input : inputs) {
@@ -70,10 +75,11 @@ XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
     inputs.push_back(loctx->GetOutputOp(operand_list[i]));
   }
   xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
-  return ReturnOps(
+  ReduceScatterResult result =
       BuildReduceScatter(reduce_type_, inputs, token, scale_, scatter_dim_,
-                         shard_count_, groups_, pin_layout_),
-      loctx);
+                         shard_count_, groups_, pin_layout_);
+  result.result.push_back(result.token);
+  return ReturnOps(result.result, loctx);
 }
 
 std::string ReduceScatter::ToString() const {
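
For reference, a minimal usage sketch of the coalesced (list-input) reduce-scatter path this patch fixes. It is not part of the patch; it only assumes the public xm.reduce_scatter signature and the list-input branch shown in the torch_xla/core/xla_model.py hunks above, and the tensor shapes and xm.REDUCE_SUM reduce type are illustrative choices.

# Hypothetical example (not part of the patch): exercise the list-input branch
# of xm.reduce_scatter, which after this change returns result[:-1] as the
# per-input shards and stores result[-1] as the new all-reduce token.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
world_size = xm.xrt_world_size()
inputs = [
    torch.ones(4 * world_size, device=device),
    torch.ones(8 * world_size, device=device),
]
outputs = xm.reduce_scatter(
    xm.REDUCE_SUM, inputs, scale=1.0, scatter_dim=0, shard_count=world_size)
# Each entry of `outputs` is this replica's shard of the summed input tensor,
# e.g. outputs[0] has shape (4,) and outputs[1] has shape (8,).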