From 79586ba8948f753a879d3887845096fd69a06ed0 Mon Sep 17 00:00:00 2001 From: Patrick Toulme <135739773+ptoulme-aws@users.noreply.github.com> Date: Wed, 27 Nov 2024 05:26:57 -0800 Subject: [PATCH] PR #18988: [WhileLoopAllReduceCodeMotion] Support convert and transpose ops in setup passes. Imported from GitHub PR https://github.com/openxla/xla/pull/18988 WhileLoopAllReduceCodeMotion does not support three very common patterns in Jax models. ``` add(transpose(convert(reduce-scatter)), buffer) add(transpose(reduce-scatter()), buffer) add(convert(reduce-scatter())), buffer) add(transpose(convert(all-reduce)), buffer) add(transpose(all-reduce()), buffer) add(convert(all-reduce())), buffer) ``` This PR adds support for running two optional setup passes before WhileLoopAllReduceCodeMotion which will seek to setup the ``` add(all-reduce/reduce-scatter(), buffer) ``` pattern. This PR adds tests that show that without this PR - the patterns above cannot be code motioned. Note these patterns are extremely prevalent with mixed precision training, and FP32 gradient accumulation buffers. Copybara import of the project: -- 7b7d22c2db31bcffc9e47bc4d94b95183932afa3 by ptoulme-aws : [WhileLoopAllReduceCodeMotion] Support convert and transpose ops in setup passes. 
Merging this change closes #18988 PiperOrigin-RevId: 700664801 --- .../xla/xla/hlo/transforms/collectives/BUILD | 31 ++ ...while_loop_all_reduce_code_motion_setup.cc | 288 ++++++++++++ .../while_loop_all_reduce_code_motion_setup.h | 70 +++ ..._loop_all_reduce_code_motion_setup_test.cc | 411 ++++++++++++++++ third_party/xla/xla/service/BUILD | 5 + .../while_loop_all_reduce_code_motion.cc | 15 + .../while_loop_all_reduce_code_motion.h | 10 +- .../while_loop_all_reduce_code_motion_test.cc | 437 ++++++++++++++++++ 8 files changed, 1265 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.cc create mode 100644 third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h create mode 100644 third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc diff --git a/third_party/xla/xla/hlo/transforms/collectives/BUILD b/third_party/xla/xla/hlo/transforms/collectives/BUILD index 44f2649df4d8ec..026d6019654bb2 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/BUILD +++ b/third_party/xla/xla/hlo/transforms/collectives/BUILD @@ -396,3 +396,34 @@ xla_cc_test( "@local_tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "while_loop_all_reduce_code_motion_setup", + srcs = ["while_loop_all_reduce_code_motion_setup.cc"], + hdrs = ["while_loop_all_reduce_code_motion_setup.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "while_loop_all_reduce_code_motion_setup_test", + srcs = ["while_loop_all_reduce_code_motion_setup_test.cc"], + deps = [ + ":while_loop_all_reduce_code_motion_setup", + "//xla/hlo/ir:hlo", + 
"//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) diff --git a/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.cc b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.cc new file mode 100644 index 00000000000000..5cf2d9441be7e7 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.cc @@ -0,0 +1,288 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h" + +#include <cstdint> + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +bool ReorderReduceTranspose::InstructionMatchesPattern( + HloInstruction* instruction) { + // Instruction must be in while loop body. 
+ if (!instruction->parent()->IsWhileBodyComputation()) { + return false; + } + // Search for Reduce Scatter Transpose pairs with optional convert in between + if (instruction->opcode() != HloOpcode::kTranspose) { + return false; + } + + HloInstruction* operand = instruction->mutable_operand(0); + + // Check if the operand is a convert instruction + if (operand->opcode() == HloOpcode::kConvert) { + operand = operand->mutable_operand(0); + } + + // Transpose operand is ReduceScatter + if (operand->opcode() != HloOpcode::kReduceScatter) { + return false; + } + + VLOG(2) << "Found Reduce Scatter (Convert) Transpose Pair:" + << operand->ToString() << "\n" + << instruction->ToString(); + + if (operand->operand_count() != 1) { + VLOG(2) << "Reject Reduce Scatter (Convert) Transpose Pair because Reduce " + "Scatter " + << "has operand count " << operand->operand_count() + << " more than 1 supported by this pass"; + return false; + } + if (instruction->user_count() == 0) { + return false; + } + + // RepeatedTransformers case + // reduce-scatter->transpose->reshape->dynamic-update-slice + if (instruction->users()[0]->opcode() == HloOpcode::kReshape) { + // Look for the dynamic update slice + auto* reshape = instruction->users()[0]; + if (reshape->user_count() == 0) { + return false; + } + return reshape->users()[0]->opcode() == HloOpcode::kDynamicUpdateSlice; + } + + // Check if the Transpose is used in an Add operation + if (instruction->users()[0]->opcode() != HloOpcode::kAdd) { + return false; + } + + HloInstruction* add_instruction = instruction->users()[0]; + + // Check if the first or second operand of the Add is a GetTupleElement whose + // operand is a Parameter + HloInstruction* second_operand = + add_instruction->operand(0)->opcode() == HloOpcode::kGetTupleElement + ? 
add_instruction->mutable_operand(0) + : add_instruction->mutable_operand(1); + if (second_operand->opcode() != HloOpcode::kGetTupleElement) { + return false; + } + HloInstruction* gte_operand = second_operand->mutable_operand(0); + if (gte_operand->opcode() != HloOpcode::kParameter) { + return false; + } + return true; +} + +absl::StatusOr<HloInstruction*> ReorderReduceTranspose::ExpandInstruction( + HloInstruction* instruction) { + auto* transpose = Cast<HloTransposeInstruction>(instruction); + HloInstruction* operand = instruction->mutable_operand(0); + + // Check if the operand is a convert instruction + bool has_convert = operand->opcode() == HloOpcode::kConvert; + auto* reduce_scatter = + has_convert + ? Cast<HloReduceScatterInstruction>(operand->mutable_operand(0)) + : Cast<HloReduceScatterInstruction>(operand); + + // Create a new Convert instruction if the original pattern had one + HloInstruction* new_convert = nullptr; + if (has_convert) { + new_convert = + instruction->parent()->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + reduce_scatter->mutable_operand(0)->shape(), + operand->shape().element_type()), + reduce_scatter->mutable_operand(0))); + } + + // Create a new Transpose instruction that uses the same dimension + // for permutation as before, but on the converted operand (if applicable) + // or the original reduce-scatter operand. + TF_ASSIGN_OR_RETURN( + auto* new_transpose, + MakeTransposeHlo( + has_convert ? new_convert : reduce_scatter->mutable_operand(0), + transpose->dimensions())); + + // Create a new ReduceScatter instruction that uses the same replica + // groups as before, but on the new transpose. The scatter dimension has + // now changed based on the transpose, so find it through the transpose + // permutation. 
+ int64_t new_scatter_dim = -1; + for (int i = 0; i < transpose->shape().rank(); i++) { + if (transpose->dimensions()[i] == reduce_scatter->scatter_dimension()) { + new_scatter_dim = i; + break; + } + } + + return instruction->parent()->AddInstruction( + HloInstruction::CreateReduceScatter( + transpose->shape(), {new_transpose}, + reduce_scatter->called_computations()[0], + reduce_scatter->replica_groups(), reduce_scatter->constrain_layout(), + reduce_scatter->channel_id(), reduce_scatter->use_global_device_ids(), + new_scatter_dim)); +} + +bool ReorderConvertReduceAdd::InstructionMatchesPattern( + HloInstruction* instruction) { + // Instruction must be in while loop body. + if (!instruction->parent()->IsWhileBodyComputation()) { + return false; + } + // Check if the instruction is an add operation + if (instruction->opcode() != HloOpcode::kAdd) { + return false; + } + + // Check if one of the operands is a convert operation + HloInstruction* convert_operand = nullptr; + HloInstruction* get_tuple_element_operand = nullptr; + for (HloInstruction* operand : instruction->operands()) { + if (operand->opcode() == HloOpcode::kConvert) { + convert_operand = operand; + } else if (operand->opcode() == HloOpcode::kGetTupleElement) { + get_tuple_element_operand = operand; + } + } + if (convert_operand == nullptr || get_tuple_element_operand == nullptr) { + return false; + } + + // Check if the operand of the convert operation is a reduce-scatter or + // all-reduce + HloInstruction* reduce_op_operand = convert_operand->mutable_operand(0); + if (reduce_op_operand->opcode() != HloOpcode::kReduceScatter && + reduce_op_operand->opcode() != HloOpcode::kAllReduce) { + return false; + } + // Check if the reduce_op_operand is a reduce-scatter and + // enable_reduce_scatter_ is true. 
+ if (!enable_reduce_scatter_ && + reduce_op_operand->opcode() == HloOpcode::kReduceScatter) { + return false; + } + + // Check if the get-tuple-element instruction is operating on a parameter + // tuple + HloInstruction* tuple_operand = get_tuple_element_operand->mutable_operand(0); + if (tuple_operand->opcode() != HloOpcode::kParameter) { + return false; + } + + VLOG(2) << "Found pattern: reduce-scatter/all-reduce, convert, add, with " + "get-tuple-element on parameter tuple"; + return true; +} + +absl::StatusOr<HloInstruction*> ReorderConvertReduceAdd::ExpandInstruction( + HloInstruction* instruction) { + VLOG(2) << "Entering ExpandInstruction"; + + // Get the add, convert, and reduce-scatter/all-reduce instructions + HloInstruction* add = instruction; + HloInstruction* convert = nullptr; + HloInstruction* other_operand = nullptr; + for (HloInstruction* operand : add->operands()) { + if (operand->opcode() == HloOpcode::kConvert) { + convert = operand; + } else { + other_operand = operand; + } + } + // Pattern matched in `InstructionMatchesPattern`. 
+ CHECK(convert != nullptr && other_operand != nullptr); + HloInstruction* reduce_op = convert->mutable_operand(0); + + VLOG(2) << "Found add: " << add->ToString(); + VLOG(2) << "Found convert: " << convert->ToString(); + VLOG(2) << "Found reduce_op: " << reduce_op->ToString(); + VLOG(2) << "Found other_operand: " << other_operand->ToString(); + + // Create a new convert instruction with the reduce-scatter/all-reduce operand + PrimitiveType new_data_type = convert->shape().element_type(); + HloInstruction* new_convert = + instruction->parent()->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(reduce_op->operand(0)->shape(), + new_data_type), + reduce_op->mutable_operand(0))); + + VLOG(2) << "Created new_convert: " << new_convert->ToString(); + + // Create a new reduce-scatter/all-reduce instruction with the converted data + // type + HloInstruction* new_reduce_op; + if (reduce_op->opcode() == HloOpcode::kReduceScatter) { + auto* reduce_scatter = Cast<HloReduceScatterInstruction>(reduce_op); + Shape new_reduce_scatter_shape = + ShapeUtil::ChangeElementType(reduce_scatter->shape(), new_data_type); + + new_reduce_op = instruction->parent()->AddInstruction( + HloInstruction::CreateReduceScatter( + new_reduce_scatter_shape, {new_convert}, + reduce_scatter->called_computations()[0], + reduce_scatter->replica_groups(), + reduce_scatter->constrain_layout(), reduce_scatter->channel_id(), + reduce_scatter->use_global_device_ids(), + reduce_scatter->scatter_dimension())); + VLOG(2) << "Created new_reduce_op (ReduceScatter): " + << new_reduce_op->ToString(); + } else { + auto* all_reduce = Cast<HloAllReduceInstruction>(reduce_op); + Shape new_all_reduce_shape = + ShapeUtil::ChangeElementType(all_reduce->shape(), new_data_type); + + new_reduce_op = + instruction->parent()->AddInstruction(HloInstruction::CreateAllReduce( + new_all_reduce_shape, {new_convert}, + all_reduce->called_computations()[0], all_reduce->replica_groups(), + all_reduce->constrain_layout(), all_reduce->channel_id(), + 
all_reduce->use_global_device_ids())); + VLOG(2) << "Created new_reduce_op (AllReduce): " + << new_reduce_op->ToString(); + } + + // Create a new add instruction with the new reduce-scatter/all-reduce and the + // other operand + HloInstruction* new_add = + instruction->parent()->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, new_reduce_op, other_operand)); + + VLOG(2) << "Created new_add: " << new_add->ToString(); + VLOG(2) << "Leaving ExpandInstruction"; + + return new_add; +} + +} // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h new file mode 100644 index 00000000000000..9ba28f9e33c75a --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h @@ -0,0 +1,70 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_SETUP_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_SETUP_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// Reorder the sequence of reduce-scatter, convert, transpose, and add +// operations. This transformation changes the pattern from: +// add(transpose(convert(reduce-scatter(operand))), get-tuple(parameter(0))) +// add(transpose(reduce-scatter(operand)), get-tuple(parameter(0))) +// to: +// add(reduce-scatter(transpose(convert(operand))), get-tuple(parameter(0))) +// add(reduce-scatter(transpose(operand)), get-tuple(parameter(0))) +class ReorderReduceTranspose : public OpExpanderPass { + public: + absl::string_view name() const override { return "reorder-reduce-transpose"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + absl::StatusOr<HloInstruction*> ExpandInstruction( + HloInstruction* instruction) override; +}; + +// Reorder the reduce-scatter/all-reduce and convert operations followed +// by an add. 
This transformation changes the pattern from: +// add(convert(reduce-scatter(operand)), get-tuple(parameter(0))) +// add(convert(all-reduce(operand)), get-tuple(parameter(0))) +// to: +// add(reduce-scatter(convert(operand)), get-tuple(parameter(0))) +// add(all-reduce(convert(operand)), get-tuple(parameter(0))) +class ReorderConvertReduceAdd : public OpExpanderPass { + public: + absl::string_view name() const override { + return "reorder-convert-reduce-add"; + } + + // Constructor with optional enable_reduce_scatter parameter + explicit ReorderConvertReduceAdd(bool enable_reduce_scatter = false) + : enable_reduce_scatter_(enable_reduce_scatter) {} + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + absl::StatusOr<HloInstruction*> ExpandInstruction( + HloInstruction* instruction) override; + // Enable transformation of reduce-scatter op. + bool enable_reduce_scatter_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_SETUP_H_ diff --git a/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc new file mode 100644 index 00000000000000..2f9717ae57628c --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc @@ -0,0 +1,411 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h" + +#include <string_view> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "tsl/platform/statusor.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class ReorderReduceTransposeTest : public HloHardwareIndependentTestBase { + protected: + ReorderReduceTransposeTest() = default; +}; + +TEST_F(ReorderReduceTransposeTest, SimpleReduceScatterTransposeInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%while_cond { + %param = (s32[4,4], s32[4,2], s32[4,2]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (s32[4,4], s32[4,2], s32[4,2]) parameter(0) + %gte.0 = s32[4,4] get-tuple-element(%param), index=0 + %gte.1 = s32[4,2] get-tuple-element(%param), index=1 + %reduce_scatter.0 = s32[2,4] reduce-scatter(%gte.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + %transpose.0 = s32[4,2] transpose(%reduce_scatter.0), dimensions={1,0} + %add.0 = s32[4,2] add(%transpose.0, %gte.1) + ROOT tuple = (s32[4,4], s32[4,2], s32[4,2]) tuple(%gte.0, %add.0, %gte.1) +} + +ENTRY main { + %init_param = (s32[4,4], s32[4,2], s32[4,2]) parameter(0) + ROOT while = (s32[4,4], s32[4,2], s32[4,2]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderReduceTranspose rrt; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rrt.Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the transpose and reduce-scatter have been reordered inside the + // while body. 
+ HloInstruction* while_inst = module->entry_computation()->root_instruction(); + HloComputation* while_body = while_inst->while_body(); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), + op::Add(op::ReduceScatter(op::Transpose()), + op::GetTupleElement()), + op::GetTupleElement())); +} + +TEST_F(ReorderReduceTransposeTest, + ReduceScatterConvertTransposeNotInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY main { + arg.0 = f32[4,4] parameter(0) + reduce_scatter.0 = f32[2,4] reduce-scatter(arg.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + convert.0 = s32[2,4] convert(reduce_scatter.0) + ROOT transpose.0 = s32[4,2] transpose(convert.0), dimensions={1,0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderReduceTranspose rrt; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rrt.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(ReorderReduceTransposeTest, ReduceScatterConvertTransposeInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%while_cond { + %param = (f32[4,4], s32[4,2], s32[4,2]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (f32[4,4], s32[4,2], s32[4,2]) parameter(0) + %gte.0 = f32[4,4] get-tuple-element(%param), index=0 + %gte.1 = s32[4,2] get-tuple-element(%param), index=1 + %reduce_scatter.0 = f32[2,4] reduce-scatter(%gte.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + %convert.0 = s32[2,4] convert(%reduce_scatter.0) + %transpose.0 = s32[4,2] transpose(%convert.0), dimensions={1,0} + %add.0 = s32[4,2] add(%transpose.0, %gte.1) + ROOT tuple = (f32[4,4], s32[4,2], s32[4,2]) tuple(%gte.0, %add.0, %gte.1) +} + +ENTRY main { + %init_param = 
(f32[4,4], s32[4,2], s32[4,2]) parameter(0) + ROOT while = (f32[4,4], s32[4,2], s32[4,2]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderReduceTranspose rrt; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rrt.Run(module.get())); + EXPECT_TRUE(changed); + // Check that the transpose, convert, and reduce-scatter have been reordered + // inside the while body. + HloInstruction* while_inst = module->entry_computation()->root_instruction(); + HloComputation* while_body = while_inst->while_body(); + + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), + op::Add(op::ReduceScatter(op::Transpose(op::Convert())), + op::GetTupleElement()), + op::GetTupleElement())); +} + +TEST_F(ReorderReduceTransposeTest, + ReduceScatterTransposeReshapeDynamicUpdateSliceInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = s32[] parameter(0) + %y = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %x, s32[] %y) +} + +%while_cond { + %param = (s32[4,4], s32[8], s32[]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (s32[4,4], s32[8], s32[]) parameter(0) + %gte.0 = s32[4,4] get-tuple-element(%param), index=0 + %gte.1 = s32[8] get-tuple-element(%param), index=1 + %gte.2 = s32[] get-tuple-element(%param), index=2 + %reduce_scatter.0 = s32[2,4] reduce-scatter(%gte.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + %transpose.0 = s32[4,2] transpose(%reduce_scatter.0), dimensions={1,0} + %reshape.0 = s32[8] reshape(%transpose.0) + %dynamic-update-slice.0 = s32[8] dynamic-update-slice(%gte.1, %reshape.0, %gte.2) + ROOT tuple = (s32[4,4], s32[8], s32[]) tuple(%gte.0, %dynamic-update-slice.0, %gte.2) +} + +ENTRY main { + %init_param = (s32[4,4], s32[8], s32[]) parameter(0) + ROOT while = (s32[4,4], s32[8], s32[]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderReduceTranspose rrt; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rrt.Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the transpose and reduce-scatter have been reordered inside the + // while body. + HloInstruction* while_inst = module->entry_computation()->root_instruction(); + HloComputation* while_body = while_inst->while_body(); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), + op::DynamicUpdateSlice( + op::GetTupleElement(), + op::Reshape(op::ReduceScatter(op::Transpose())), + op::GetTupleElement()), + op::GetTupleElement())); +} + +class ReorderConvertReduceAddTest : public HloHardwareIndependentTestBase { + protected: + ReorderConvertReduceAddTest() = default; +}; + +TEST_F(ReorderConvertReduceAddTest, SimpleConvertReduceScatterAddInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%while_cond { + %param = (f32[4,4], s32[2,4], s32[]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (f32[4,4], s32[2,4], s32[]) parameter(0) + %gte.0 = f32[4,4] get-tuple-element(%param), index=0 + %gte.1 = s32[2,4] get-tuple-element(%param), index=1 + %gte.2 = s32[] get-tuple-element(%param), index=2 + %reduce_scatter.0 = f32[2,4] reduce-scatter(%gte.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + %convert.0 = s32[2,4] convert(%reduce_scatter.0) + %add.0 = s32[2,4] add(%convert.0, %gte.1) + ROOT tuple = (f32[4,4], s32[2,4], s32[]) tuple(%gte.0, %add.0, %gte.2) +} + +ENTRY main { + %init_param = (f32[4,4], s32[2,4], s32[]) parameter(0) + ROOT while = (f32[4,4], s32[2,4], s32[]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderConvertReduceAdd 
rcra(/*enable_reduce_scatter=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, rcra.Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the convert, reduce-scatter, and add have been reordered inside + // the while body. + HloInstruction* while_inst = module->entry_computation()->root_instruction(); + HloComputation* while_body = while_inst->while_body(); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), + op::Add(op::ReduceScatter(op::Convert()), + op::GetTupleElement()), + op::GetTupleElement())); +} + +TEST_F(ReorderConvertReduceAddTest, ConvertAllReduceAddNotInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY main { + arg.0 = f32[4,4] parameter(0) + all_reduce.0 = f32[4,4] all-reduce(arg.0), replica_groups={{0,1}}, to_apply=%reduction + ROOT convert.0 = s32[4,4] convert(all_reduce.0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderConvertReduceAdd rcra; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rcra.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(ReorderConvertReduceAddTest, ConvertReduceScatterAddInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%while_cond { + %param = (f32[4,4], s32[2,4], s32[]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (f32[4,4], s32[2,4], s32[]) parameter(0) + %gte.0 = f32[4,4] get-tuple-element(%param), index=0 + %gte.1 = s32[2,4] get-tuple-element(%param), index=1 + %gte.2 = s32[] get-tuple-element(%param), index=2 + %reduce_scatter.0 = f32[2,4] reduce-scatter(%gte.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + %convert.0 = s32[2,4] convert(%reduce_scatter.0) + %add.0 = s32[2,4] add(%convert.0, %gte.1) + ROOT tuple = 
(f32[4,4], s32[2,4], s32[]) tuple(%gte.0, %add.0, %gte.2) +} + +ENTRY main { + %init_param = (f32[4,4], s32[2,4], s32[]) parameter(0) + ROOT while = (f32[4,4], s32[2,4], s32[]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderConvertReduceAdd rcra(/*enable_reduce_scatter=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, rcra.Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the convert, reduce-scatter, and add have been reordered inside + // the while body. + HloInstruction* while_inst = module->entry_computation()->root_instruction(); + HloComputation* while_body = while_inst->while_body(); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), + op::Add(op::ReduceScatter(op::Convert()), + op::GetTupleElement()), + op::GetTupleElement())); +} + +TEST_F(ReorderConvertReduceAddTest, DisableReduceScatter) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%while_cond { + %param = (f32[4,4], s32[2,4], s32[]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (f32[4,4], s32[2,4], s32[]) parameter(0) + %gte.0 = f32[4,4] get-tuple-element(%param), index=0 + %gte.1 = s32[2,4] get-tuple-element(%param), index=1 + %gte.2 = s32[] get-tuple-element(%param), index=2 + %reduce_scatter.0 = f32[2,4] reduce-scatter(%gte.0), dimensions={0}, replica_groups={{0,1}}, to_apply=%reduction + %convert.0 = s32[2,4] convert(%reduce_scatter.0) + %add.0 = s32[2,4] add(%convert.0, %gte.1) + ROOT tuple = (f32[4,4], s32[2,4], s32[]) tuple(%gte.0, %add.0, %gte.2) +} + +ENTRY main { + %init_param = (f32[4,4], s32[2,4], s32[]) parameter(0) + ROOT while = (f32[4,4], s32[2,4], s32[]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, 
ParseAndReturnVerifiedModule(hlo)); + ReorderConvertReduceAdd rcra(/*enable_reduce_scatter=*/false); + TF_ASSERT_OK_AND_ASSIGN(bool changed, rcra.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(ReorderConvertReduceAddTest, ConvertAllReduceAddInWhileBody) { + constexpr std::string_view hlo = R"( +HloModule main + +%reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%while_cond { + %param = (f32[2,4], s32[2,4], s32[]) parameter(0) + ROOT cond = pred[] constant(true) +} + +%while_body { + %param = (f32[2,4], s32[2,4], s32[]) parameter(0) + %gte.0 = f32[2,4] get-tuple-element(%param), index=0 + %gte.1 = s32[2,4] get-tuple-element(%param), index=1 + %gte.2 = s32[] get-tuple-element(%param), index=2 + %all_reduce.0 = f32[2,4] all-reduce(%gte.0), replica_groups={{0,1}}, to_apply=%reduction + %convert.0 = s32[2,4] convert(%all_reduce.0) + %add.0 = s32[2,4] add(%convert.0, %gte.1) + ROOT tuple = (f32[2,4], s32[2,4], s32[]) tuple(%gte.0, %add.0, %gte.2) +} + +ENTRY main { + %init_param = (f32[2,4], s32[2,4], s32[]) parameter(0) + ROOT while = (f32[2,4], s32[2,4], s32[]) while(%init_param), condition=%while_cond, body=%while_body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + ReorderConvertReduceAdd rcra; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rcra.Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the convert, all-reduce, and add have been reordered inside the + // while body. 
+ HloInstruction* while_inst = module->entry_computation()->root_instruction(); + HloComputation* while_body = while_inst->while_body(); + EXPECT_THAT( + while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), + op::Add(op::AllReduce(op::Convert()), op::GetTupleElement()), + op::GetTupleElement())); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 3c2d8350c0e26b..e5d0a131b141c4 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4852,6 +4852,8 @@ cc_library( "//xla/hlo/analysis:hlo_replication_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms/collectives:while_loop_all_reduce_code_motion_setup", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -4874,7 +4876,10 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc index aac2a22abfe0c5..eb235a6396edd9 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc @@ -32,6 +32,8 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal_util.h" #include "xla/map_util.h" @@ -961,6 +963,19 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::Run( HloReplicationAnalysis::RunWithPartialReplication( module, /*cross_partition_spmd=*/true)); } + + // Run setup passes that may setup the add(all-reduce/reduce-scatter, + // accumulation_buffer) pattern. + if (run_setup_passes_) { + HloPassPipeline pipeline("while-loop-all-reduce-code-motion-setup"); + if (enable_reduce_scatter_) { + pipeline.AddPass(); + } + pipeline.AddPass( + /*enable_reduce_scatter=*/enable_reduce_scatter_); + TF_RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status()); + } + // The while instruction's parent could be a while body for another while // loop. We recursively sink the all-reduce through nested while loops if // applicable by repeating this process. 
diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.h b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.h index 730a06c9e662e2..e3b30c90850df1 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.h +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.h @@ -44,8 +44,10 @@ namespace xla { // a += e class WhileLoopAllReduceCodeMotion : public HloModulePass { public: - explicit WhileLoopAllReduceCodeMotion(bool enable_reduce_scatter = false) - : enable_reduce_scatter_(enable_reduce_scatter) {} + explicit WhileLoopAllReduceCodeMotion(bool enable_reduce_scatter = false, + bool run_setup_passes = false) + : enable_reduce_scatter_(enable_reduce_scatter), + run_setup_passes_(run_setup_passes) {} ~WhileLoopAllReduceCodeMotion() override = default; absl::string_view name() const override { @@ -58,6 +60,10 @@ class WhileLoopAllReduceCodeMotion : public HloModulePass { private: const bool enable_reduce_scatter_; + + // Whether to run passes that may setup the add(all-reduce/reduce-scatter, + // accumulation_buffer) pattern. + const bool run_setup_passes_; }; } // namespace xla diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc b/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc index 271b04cd4e2643..b422f2b3bf46ab 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc @@ -17,12 +17,17 @@ limitations under the License. #include #include +#include #include +#include #include #include #include +#include +#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -32,6 +37,7 @@ limitations under the License. 
#include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -1175,5 +1181,436 @@ TEST_F(WhileLoopAllReduceCodeMotionTest, EXPECT_FALSE(simplified_loop); } +// This test checks the add(transpose(reduce-scatter()), buffer) case +// code motions when setup passes are enabled. +TEST_F(WhileLoopAllReduceCodeMotionTest, ReduceScatterTransposeAccumulate) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_reduce_scatter + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + %while_condition { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[4096, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %reduce-scatter = f32[1024, 1024] reduce-scatter(f32[4096, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction, dimensions={0} + %transpose.0 = f32[1024,1024] transpose(%reduce-scatter), dimensions={1,0} + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %transpose.0, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.0 = s32[] parameter(0) + %param.1 = f32[4096, 1024] parameter(1) + %constant.0 = s32[] constant(1) + 
%accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[4096, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + ROOT %while = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + (WhileLoopAllReduceCodeMotion{/*enable_reduce_scatter=*/true, + /*run_setup_passes=*/true} + .Run(module.get()))); + ASSERT_TRUE(simplified_loop); + TF_ASSERT_OK( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + HloComputation* entry = module->entry_computation(); + HloInstruction* transformed_while = find_op(entry); + ASSERT_THAT(transformed_while, NotNull()); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::ReduceScatter()))); + HloInstruction* accumulation_buffer = + transformed_while->mutable_operand(0)->mutable_operand(3); + EXPECT_THAT(accumulation_buffer, op::Constant()); + // Verify that the accumulation buffer's shape changed. 
+ EXPECT_THAT(accumulation_buffer, op::Shape("f32[1024, 4096]")); + auto* moved_reduce_scatter = DynCast( + find_op(entry)); + ASSERT_THAT(moved_reduce_scatter, NotNull()); + EXPECT_THAT(moved_reduce_scatter->operand(0), op::GetTupleElement()); + EXPECT_EQ(DynCast( + moved_reduce_scatter->mutable_operand(0)) + ->tuple_index(), + 3); + EXPECT_THAT(moved_reduce_scatter, op::ReplicaGroups({{0, 1, 2, 3}})); + EXPECT_FALSE(moved_reduce_scatter->constrain_layout()); + EXPECT_TRUE(moved_reduce_scatter->use_global_device_ids()); + HloComputation* reduction_computation = + module->GetComputationWithName("reduction"); + ASSERT_THAT(reduction_computation, NotNull()); + EXPECT_EQ(moved_reduce_scatter->to_apply(), reduction_computation); +} + +// This test checks the add(transose(reduce-scatter()), buffer) case +// does not code motion when setup passes are disabled. +TEST_F(WhileLoopAllReduceCodeMotionTest, + ReduceScatterTransposeAccumulateNoMotion) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_reduce_scatter + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + %while_condition { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[4096, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %reduce-scatter = f32[1024, 1024] reduce-scatter(f32[4096, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction, dimensions={0} + %transpose.0 = f32[1024,1024] transpose(%reduce-scatter), 
dimensions={1,0} + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %reduce-scatter, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.0 = s32[] parameter(0) + %param.1 = f32[4096, 1024] parameter(1) + %constant.0 = s32[] constant(1) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[4096, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + ROOT %while = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + (WhileLoopAllReduceCodeMotion{/*enable_reduce_scatter=*/true, + /*run_setup_passes=*/false} + .Run(module.get()))); + ASSERT_FALSE(simplified_loop); +} + +// This test checks the add(transpose(convert(reduce-scatter())), buffer) case +// code motions when setup passes are enabled. 
+TEST_F(WhileLoopAllReduceCodeMotionTest, + ReduceScatterTransposeConvertAccumulate) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_reduce_scatter + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + %while_condition { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[4096, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %convert.0 = bf16[4096, 1024] convert(f32[4096, 1024] %gte.2) + %reduce-scatter = bf16[1024, 1024] reduce-scatter(bf16[4096, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction, dimensions={0} + %convert.1 = f32[1024,1024] convert(bf16[1024, 1024] %reduce-scatter) + %transpose.0 = f32[1024,1024] transpose(%convert.1), dimensions={1,0} + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %transpose.0, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.0 = s32[] parameter(0) + %param.1 = f32[4096, 1024] parameter(1) + %constant.0 = s32[] constant(1) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[4096, 
1024] %param.1, f32[1024, 1024] %accumulation_buffer) + ROOT %while = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + (WhileLoopAllReduceCodeMotion{/*enable_reduce_scatter=*/true, + /*run_setup_passes=*/true} + .Run(module.get()))); + ASSERT_TRUE(simplified_loop); + TF_ASSERT_OK( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + HloComputation* entry = module->entry_computation(); + HloInstruction* transformed_while = find_op(entry); + ASSERT_THAT(transformed_while, NotNull()); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::ReduceScatter()))); + HloInstruction* accumulation_buffer = + transformed_while->mutable_operand(0)->mutable_operand(3); + EXPECT_THAT(accumulation_buffer, op::Constant()); + // Verify that the accumulation buffer's shape changed. + EXPECT_THAT(accumulation_buffer, op::Shape("f32[1024, 4096]")); + auto* moved_reduce_scatter = DynCast( + find_op(entry)); + ASSERT_THAT(moved_reduce_scatter, NotNull()); + EXPECT_THAT(moved_reduce_scatter->operand(0), op::GetTupleElement()); + EXPECT_EQ(DynCast( + moved_reduce_scatter->mutable_operand(0)) + ->tuple_index(), + 3); + EXPECT_THAT(moved_reduce_scatter, op::ReplicaGroups({{0, 1, 2, 3}})); + EXPECT_FALSE(moved_reduce_scatter->constrain_layout()); + EXPECT_TRUE(moved_reduce_scatter->use_global_device_ids()); + HloComputation* reduction_computation = + module->GetComputationWithName("reduction"); + ASSERT_THAT(reduction_computation, NotNull()); + EXPECT_EQ(moved_reduce_scatter->to_apply(), reduction_computation); +} + +// This test checks the add(transpose(convert(reduce-scatter())), buffer) case +// does not code motion when setup passes are disabled. 
+TEST_F(WhileLoopAllReduceCodeMotionTest, + ReduceScatterTransposeConvertDisabledAccumulate) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_reduce_scatter + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + %while_condition { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[4096, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %convert.0 = bf16[4096, 1024] convert(f32[4096, 1024] %gte.2) + %reduce-scatter = bf16[1024, 1024] reduce-scatter(bf16[4096, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction, dimensions={0} + %convert.1 = f32[1024,1024] convert(bf16[1024, 1024] %reduce-scatter) + %transpose.0 = f32[1024,1024] transpose(%reduce-scatter), dimensions={1,0} + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %transpose.0, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.0 = s32[] parameter(0) + %param.1 = f32[4096, 1024] parameter(1) + %constant.0 = s32[] constant(1) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] 
%param.0, f32[4096, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + ROOT %while = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + (WhileLoopAllReduceCodeMotion{/*enable_reduce_scatter=*/true, + /*run_setup_passes=*/false} + .Run(module.get()))); + ASSERT_FALSE(simplified_loop); +} + +// This test checks the add((convert(reduce-scatter()), buffer) case +// code motions when setup passes are enabled. +TEST_F(WhileLoopAllReduceCodeMotionTest, ReduceScatterConvertAccumulate) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_reduce_scatter + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + %while_condition { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[4096, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %convert.0 = bf16[4096, 1024] convert(f32[4096, 1024] %gte.2) + %reduce-scatter = bf16[1024, 1024] reduce-scatter(bf16[4096, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction, dimensions={0} + %convert.1 = f32[1024,1024] convert(bf16[1024, 1024] %reduce-scatter) + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] 
%constant) + ROOT %loop_result = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.0 = s32[] parameter(0) + %param.1 = f32[4096, 1024] parameter(1) + %constant.0 = s32[] constant(1) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[4096, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + ROOT %while = (s32[], s32[], f32[4096, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + (WhileLoopAllReduceCodeMotion{/*enable_reduce_scatter=*/true, + /*run_setup_passes=*/true} + .Run(module.get()))); + ASSERT_TRUE(simplified_loop); + TF_ASSERT_OK( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + HloComputation* entry = module->entry_computation(); + HloInstruction* transformed_while = find_op(entry); + ASSERT_THAT(transformed_while, NotNull()); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::ReduceScatter()))); + HloInstruction* accumulation_buffer = + transformed_while->mutable_operand(0)->mutable_operand(3); + EXPECT_THAT(accumulation_buffer, op::Constant()); + // Verify that the accumulation buffer's shape changed. 
+ EXPECT_THAT(accumulation_buffer, op::Shape("f32[4096, 1024]")); + auto* moved_reduce_scatter = DynCast( + find_op(entry)); + ASSERT_THAT(moved_reduce_scatter, NotNull()); + EXPECT_THAT(moved_reduce_scatter->operand(0), op::GetTupleElement()); + EXPECT_EQ(DynCast( + moved_reduce_scatter->mutable_operand(0)) + ->tuple_index(), + 3); + EXPECT_THAT(moved_reduce_scatter, op::ReplicaGroups({{0, 1, 2, 3}})); + EXPECT_FALSE(moved_reduce_scatter->constrain_layout()); + EXPECT_TRUE(moved_reduce_scatter->use_global_device_ids()); + HloComputation* reduction_computation = + module->GetComputationWithName("reduction"); + ASSERT_THAT(reduction_computation, NotNull()); + EXPECT_EQ(moved_reduce_scatter->to_apply(), reduction_computation); +} + +// This test checks the add((convert(all-reduce()), buffer) case +// code motions when setup passes are enabled. +TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceConvertAccumulateUse) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_all_reduce + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + %while_condition { + %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2) + %all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction + %convert.1 = f32[1024,1024] convert(bf16[1024, 1024] 
%all-reduce) + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.0 = s32[] parameter(0) + %param.1 = f32[1024, 1024] parameter(1) + %constant.0 = s32[] constant(1) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + %gte_while = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3 + ROOT %multiply = f32[1024, 1024] multiply(f32[1024, 1024] %gte_while, f32[1024, 1024] %param.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + (WhileLoopAllReduceCodeMotion{/*enable_reduce_scatter=*/true, + /*run_setup_passes=*/true} + .Run(module.get()))); + ASSERT_TRUE(simplified_loop); + TF_ASSERT_OK( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + HloComputation* entry = module->entry_computation(); + HloInstruction* transformed_while = find_op(entry); + + ASSERT_THAT(transformed_while, NotNull()); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::AllReduce()))); + HloInstruction* new_root = module->entry_computation()->root_instruction(); + ASSERT_THAT(new_root, op::Multiply()); + ASSERT_THAT(new_root->operand(0), op::GetTupleElement()); + 
ASSERT_THAT(new_root->operand(0)->operand(0), op::Tuple()); + EXPECT_THAT(new_root->operand(0)->operand(0)->operand(3), op::Add()); +} + } // namespace } // namespace xla