From 68d4e6fe68250705bdc2a79c8b998b42d05bbc7c Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 3 Dec 2024 23:31:11 -0800 Subject: [PATCH] Move ForAllThunk into Thunk base class `ForAllThunks` is currently implemented as a free function. It's a switch statement which handles all thunk types. Each thunk gets downcast to its real type so that its nested thunks can be visited. This change makes `ForAllThunks` a virtual function on `Thunk`. `Thunk` provides a default implementation which just visits `this`, and thunks with nested thunks can override `ForAllThunks` to adjust the behaviour. PiperOrigin-RevId: 702601594 --- third_party/xla/xla/service/gpu/BUILD | 3 - .../service/gpu/compile_module_to_llvm_ir.h | 5 - third_party/xla/xla/service/gpu/fusions/BUILD | 2 +- .../xla/xla/service/gpu/gpu_executable.cc | 13 +- third_party/xla/xla/service/gpu/runtime/BUILD | 27 +--- .../gpu/runtime/command_buffer_thunk.cc | 8 ++ .../gpu/runtime/command_buffer_thunk.h | 3 + .../service/gpu/runtime/conditional_thunk.cc | 11 ++ .../service/gpu/runtime/conditional_thunk.h | 3 + .../gpu/runtime/dynamic_slice_thunk.cc | 6 + .../service/gpu/runtime/dynamic_slice_thunk.h | 3 + .../xla/service/gpu/runtime/for_all_thunks.cc | 128 ------------------ .../xla/service/gpu/runtime/for_all_thunks.h | 34 ----- .../gpu/runtime/for_all_thunks_test.cc | 4 +- .../service/gpu/runtime/sequential_thunk.cc | 11 ++ .../service/gpu/runtime/sequential_thunk.h | 3 + .../xla/xla/service/gpu/runtime/thunk.cc | 5 + .../xla/xla/service/gpu/runtime/thunk.h | 4 + .../xla/service/gpu/runtime/while_thunk.cc | 7 + .../xla/xla/service/gpu/runtime/while_thunk.h | 3 + 20 files changed, 81 insertions(+), 202 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc delete mode 100644 third_party/xla/xla/service/gpu/runtime/for_all_thunks.h diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 99eebdfaa86e6c..702d8a773c7616 100644 --- 
a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -588,7 +588,6 @@ cc_library( "//xla/service:stream_pool", "//xla/service:xla_debug_info_manager", "//xla/service/gpu/runtime:annotation", - "//xla/service/gpu/runtime:for_all_thunks", "//xla/service/gpu/runtime:nccl_api", "//xla/service/gpu/runtime:nccl_clique", "//xla/service/gpu/runtime:sequential_thunk", @@ -1280,7 +1279,6 @@ cc_library( hdrs = [ "compile_module_to_llvm_ir.h", ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":executable_proto_cc", ":execution_stream_assignment", @@ -1304,7 +1302,6 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:logical_buffer", "//xla/service/gpu/runtime:sequential_thunk", - "//xla/service/gpu/runtime:thunk", "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/stream_executor/rocm:rocm_platform_id", diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index 3e75fa156ac8c6..d3ddc4a14fc474 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_COMPILE_MODULE_TO_LLVM_IR_H_ #define XLA_SERVICE_GPU_COMPILE_MODULE_TO_LLVM_IR_H_ -#include #include #include #include @@ -34,7 +33,6 @@ limitations under the License. 
#include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/runtime/sequential_thunk.h" -#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -64,9 +62,6 @@ struct CompileModuleResults { bool use_original_allocations; }; -void ForAllThunks(const std::function& fn, - ThunkSequence* thunk_sequence); - absl::Status LoadCache(IrEmitterContext& ir_emitter_context, absl::string_view cache_file_path); diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 6841b56ec6131e..80c504bbff8571 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -374,7 +374,7 @@ cc_library( xla_cc_test( name = "triton_test", srcs = ["triton_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + tags = ["gpu"], deps = [ ":fusion_emitter", ":fusions", diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 6ccba1074f5866..405e831015c97f 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -50,7 +50,6 @@ limitations under the License. 
#include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/runtime/annotation.h" -#include "xla/service/gpu/runtime/for_all_thunks.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique.h" #include "xla/service/gpu/runtime/sequential_thunk.h" @@ -99,13 +98,11 @@ using ::tsl::profiler::ScopedAnnotation; static absl::flat_hash_set GetExecutionStreamIds( const SequentialThunk& thunks) { absl::flat_hash_set stream_ids; - ForAllThunks( - [&](const Thunk* thunk) { - if (thunk->execution_stream_id() > 0) { - stream_ids.insert(thunk->execution_stream_id()); - } - }, - &thunks); + thunks.ForAllThunks([&](const Thunk* thunk) { + if (thunk->execution_stream_id() > 0) { + stream_ids.insert(thunk->execution_stream_id()); + } + }); return stream_ids; } diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index cf30f4f1a526b8..788de5ae39415e 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -370,6 +370,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", @@ -473,6 +474,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -557,6 +559,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -1170,6 +1173,7 @@ cc_library( deps = [ ":annotation", ":thunk", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", @@ -1216,7 +1220,6 @@ cc_library( "//xla/core/collectives:rank_id", "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", - "//xla/hlo/translate/mhlo_to_hlo:location_exporter", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:global_device_id", @@ -1227,44 +1230,27 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/gtl:int_type", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) -cc_library( - name = "for_all_thunks", - srcs = ["for_all_thunks.cc"], - hdrs = ["for_all_thunks.h"], - deps = [ - ":command_buffer_thunk", - ":conditional_thunk", - ":dynamic_slice_thunk", - ":sequential_thunk", - ":thunk", - ":while_thunk", - "@com_google_absl//absl/functional:function_ref", - "@local_tsl//tsl/platform:casts", - ], -) - xla_cc_test( name = "for_all_thunks_test", srcs = ["for_all_thunks_test.cc"], + tags = ["gpu"], deps = [ ":command_buffer_cmd", ":command_buffer_thunk", ":conditional_thunk", ":dynamic_slice_thunk", - ":for_all_thunks", ":sequential_thunk", ":thunk", ":while_thunk", @@ -1315,6 +1301,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", 
"@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc index ac0f8c766d485d..a345da0d155dbd 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" @@ -340,4 +341,11 @@ void CommandBufferThunk::EvictCommandBuffers() { } } +void CommandBufferThunk::ForAllThunks( + absl::FunctionRef fn) const { + fn(this); + if (thunks_ != nullptr) { + thunks_->ForAllThunks(fn); + } +} } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h index a0442f3d711023..9154d8c067bac7 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" @@ -54,6 +55,8 @@ class CommandBufferThunk : public Thunk { absl::StatusOr GetCommandBufferAllocationAddress( const ExecuteParams& params, int64_t index); + void ForAllThunks(absl::FunctionRef fn) const override; + private: // Command buffer instantiated on a `se::StreamExecutor` instance, and // auxiliary state required for efficient command buffer updates. diff --git a/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc index 2b797e32e1d920..88c7273744cd17 100644 --- a/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc @@ -21,9 +21,11 @@ limitations under the License. #include #include +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/status_macros.h" @@ -134,5 +136,14 @@ absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { return absl::OkStatus(); } +void ConditionalThunk::ForAllThunks( + absl::FunctionRef fn) const { + fn(this); + for (const std::unique_ptr& branch_thunk : + config_.branch_thunks) { + branch_thunk->ForAllThunks(fn); + } +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/conditional_thunk.h b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.h index 323ee11fcaeb30..833df0539006c2 100644 --- a/third_party/xla/xla/service/gpu/runtime/conditional_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.h @@ -22,6 +22,7 @@ 
limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -71,6 +72,8 @@ class ConditionalThunk : public Thunk { return branch_index_buffer_index_; } + void ForAllThunks(absl::FunctionRef fn) const override; + private: const ConditionalThunkConfig config_; const BufferAllocation::Slice branch_index_buffer_index_; diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc index fda7bc954ddc36..d0ec3a65283710 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "llvm/ADT/STLExtras.h" @@ -254,5 +255,10 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) { return absl::OkStatus(); } +void DynamicSliceThunk::ForAllThunks( + absl::FunctionRef fn) const { + fn(this); + embedded_thunk_->ForAllThunks(fn); +} } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h index 1b283d036da2a3..6adc4a62f72d9d 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/literal.h" @@ -107,6 +108,8 @@ class DynamicSliceThunk : public Thunk { return offset_byte_sizes_; } + void ForAllThunks(absl::FunctionRef fn) const override; + private: std::unique_ptr embedded_thunk_; std::vector> arguments_; diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc deleted file mode 100644 index ded4ab4bdb1652..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/runtime/for_all_thunks.h" - -#include -#include - -#include "absl/functional/function_ref.h" -#include "xla/service/gpu/runtime/command_buffer_thunk.h" -#include "xla/service/gpu/runtime/conditional_thunk.h" -#include "xla/service/gpu/runtime/dynamic_slice_thunk.h" -#include "xla/service/gpu/runtime/sequential_thunk.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/service/gpu/runtime/while_thunk.h" -#include "tsl/platform/casts.h" - -namespace xla::gpu { - -void ForAllThunks(absl::FunctionRef fn, - const Thunk* thunk) { - // Invoke `fn` with the `Thunk` itself first... - fn(thunk); - // ... and then handle all nested `Thunks` recursively. - switch (thunk->kind()) { - case Thunk::kDynamicSlice: - ForAllThunks(fn, tensorflow::down_cast(thunk) - ->embedded_thunk()); - break; - case Thunk::kCommandBuffer: - if (const std::unique_ptr& sequence = - tensorflow::down_cast(thunk)->thunks(); - sequence != nullptr) { - ForAllThunks(fn, sequence.get()); - } - break; - case Thunk::kConditional: - for (const std::unique_ptr& branch : - tensorflow::down_cast(thunk) - ->branch_thunks()) { - ForAllThunks(fn, branch.get()); - } - break; - case Thunk::kSequential: - ForAllThunks( - fn, &tensorflow::down_cast(thunk)->thunks()); - break; - case Thunk::kWhile: - ForAllThunks(fn, tensorflow::down_cast(thunk) - ->condition_thunk_sequence()); - ForAllThunks(fn, tensorflow::down_cast(thunk) - ->body_thunk_sequence()); - break; - case Thunk::kCholesky: - case Thunk::kConvolution: - case Thunk::kConvolutionReorder: - case Thunk::kCopy: - case Thunk::kCopyDone: - case Thunk::kCubSort: - case Thunk::kCublasLtMatmul: - case Thunk::kCustomCall: - case Thunk::kCustomKernel: - case Thunk::kCuDnn: - case Thunk::kFft: - case Thunk::kGemm: - case Thunk::kInfeed: - case Thunk::kKernel: - case Thunk::kMemset32BitValue: - case Thunk::kMemzero: - case Thunk::kNcclAllGather: - case 
Thunk::kNcclAllGatherStart: - case Thunk::kNcclAllGatherDone: - case Thunk::kNcclAllReduce: - case Thunk::kNcclAllReduceStart: - case Thunk::kNcclAllReduceDone: - case Thunk::kNcclCollectiveBroadcast: - case Thunk::kNcclCollectiveBroadcastStart: - case Thunk::kNcclCollectiveBroadcastDone: - case Thunk::kNcclCollectivePermute: - case Thunk::kNcclCollectivePermuteStart: - case Thunk::kNcclCollectivePermuteDone: - case Thunk::kNcclReduceScatter: - case Thunk::kNcclReduceScatterStart: - case Thunk::kNcclReduceScatterDone: - case Thunk::kNcclAllToAll: - case Thunk::kNcclAllToAllStart: - case Thunk::kNcclAllToAllDone: - case Thunk::kNcclSend: - case Thunk::kNcclSendDone: - case Thunk::kNcclRecv: - case Thunk::kNcclRecvDone: - case Thunk::kNorm: - case Thunk::kOutfeed: - case Thunk::kPartitionId: - case Thunk::kRecv: - case Thunk::kRecvDone: - case Thunk::kNcclGroupStart: - case Thunk::kNcclGroupDone: - case Thunk::kReplicaId: - case Thunk::kSend: - case Thunk::kSendDone: - case Thunk::kTriangularSolve: - case Thunk::kWaitForStreams: - // No default. All `Thunk::Kinds` must be handled. - break; - } -} - -void ForAllThunks(absl::FunctionRef fn, - const ThunkSequence* thunks) { - for (const std::unique_ptr& thunk : *thunks) { - ForAllThunks(fn, thunk.get()); - } -} - -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.h b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.h deleted file mode 100644 index 6f6fc61c34c427..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FOR_ALL_THUNKS_H_ -#define XLA_SERVICE_GPU_RUNTIME_FOR_ALL_THUNKS_H_ - -#include "absl/functional/function_ref.h" -#include "xla/service/gpu/runtime/thunk.h" - -namespace xla::gpu { - -// Recursively invokes `fn` for all `Thunks` in `root`, including those nested -// within other `Thunks` (e.g. the condition `Thunk` within a `WhileThunk`). -void ForAllThunks(absl::FunctionRef fn, const Thunk* thunk); - -// Same as above but for a `ThunkSequence` root. -void ForAllThunks(absl::FunctionRef fn, - const ThunkSequence* thunks); - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME_FOR_ALL_THUNKS_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc index 6220e55fbcc4a9..f2558f687ecc5f 100644 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/for_all_thunks.h" - #include #include #include @@ -41,7 +39,7 @@ using ::testing::UnorderedElementsAre; // iterated `Thunks`. 
std::vector GetAllThunks(Thunk* root) { std::vector thunks; - ForAllThunks([&](const Thunk* thunk) { thunks.push_back(thunk); }, root); + root->ForAllThunks([&](const Thunk* thunk) { thunks.push_back(thunk); }); return thunks; } diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index ac2e652dcb1537..b7f051d1d119ed 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -15,9 +15,12 @@ limitations under the License. #include "xla/service/gpu/runtime/sequential_thunk.h" +#include +#include #include #include +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/service/gpu/runtime/annotation.h" @@ -83,5 +86,13 @@ absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { return absl::OkStatus(); } +void SequentialThunk::ForAllThunks( + absl::FunctionRef fn) const { + fn(this); + for (const std::unique_ptr& thunk : thunks_) { + thunk->ForAllThunks(fn); + } +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h index d754d42f394865..ea4fbdc9d1f92f 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "xla/service/gpu/runtime/thunk.h" @@ -42,6 +43,8 @@ class SequentialThunk : public Thunk { absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; + void ForAllThunks(absl::FunctionRef fn) const override; + private: // The list of sub-thunks. 
ThunkSequence thunks_; diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index e7ccc1e360e3b5..f0c09a5541ba7f 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -366,5 +367,9 @@ bool Thunk::IsCollective() const { } } +void Thunk::ForAllThunks(absl::FunctionRef fn) const { + fn(this); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 3c07c4f852d2b5..acb5d1378f3a87 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -486,6 +487,9 @@ class Thunk { // Returns `true` if this thunk requires inter-GPU communication. bool IsCollective() const; + // Invokes `fn` with this thunk and all nested thunks. + virtual void ForAllThunks(absl::FunctionRef fn) const; + private: Kind kind_; std::string profile_annotation_; diff --git a/third_party/xla/xla/service/gpu/runtime/while_thunk.cc b/third_party/xla/xla/service/gpu/runtime/while_thunk.cc index d3ad896b10793b..12868d9ed71611 100644 --- a/third_party/xla/xla/service/gpu/runtime/while_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/while_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "absl/cleanup/cleanup.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" @@ -142,5 +143,11 @@ absl::Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { return absl::OkStatus(); } +void WhileThunk::ForAllThunks(absl::FunctionRef fn) const { + fn(this); + condition_thunk_sequence_->ForAllThunks(fn); + body_thunk_sequence_->ForAllThunks(fn); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/while_thunk.h b/third_party/xla/xla/service/gpu/runtime/while_thunk.h index 3ab0069c1a897f..97a7a08808381b 100644 --- a/third_party/xla/xla/service/gpu/runtime/while_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/while_thunk.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" @@ -82,6 +83,8 @@ class WhileThunk : public Thunk { // code running on multiple threads. static absl::StatusOr CurrentLoopIteration(int64_t depth = 0); + void ForAllThunks(absl::FunctionRef fn) const override; + private: const BufferAllocation::Slice condition_result_buffer_index_; std::unique_ptr condition_thunk_sequence_;