Skip to content

Commit

Permalink
Move ForAllThunks into Thunk base class
Browse files Browse the repository at this point in the history
`ForAllThunks` is currently implemented as a free function. It is a switch statement
that handles all thunk types: each thunk is downcast to its concrete type in order
to visit its nested thunks.

This change makes `ForAllThunks` a virtual function on `Thunk`. `Thunk` provides
a default implementation that just visits `this`, and thunks with nested thunks
can override `ForAllThunks` to adjust the behaviour.

PiperOrigin-RevId: 702601594
  • Loading branch information
beckerhe authored and tensorflower-gardener committed Dec 4, 2024
1 parent b014417 commit 68d4e6f
Show file tree
Hide file tree
Showing 20 changed files with 81 additions and 202 deletions.
3 changes: 0 additions & 3 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@ cc_library(
"//xla/service:stream_pool",
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu/runtime:annotation",
"//xla/service/gpu/runtime:for_all_thunks",
"//xla/service/gpu/runtime:nccl_api",
"//xla/service/gpu/runtime:nccl_clique",
"//xla/service/gpu/runtime:sequential_thunk",
Expand Down Expand Up @@ -1280,7 +1279,6 @@ cc_library(
hdrs = [
"compile_module_to_llvm_ir.h",
],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":executable_proto_cc",
":execution_stream_assignment",
Expand All @@ -1304,7 +1302,6 @@ cc_library(
"//xla/service:hlo_proto_cc",
"//xla/service:logical_buffer",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor/rocm:rocm_platform_id",
Expand Down
5 changes: 0 additions & 5 deletions third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_COMPILE_MODULE_TO_LLVM_IR_H_
#define XLA_SERVICE_GPU_COMPILE_MODULE_TO_LLVM_IR_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>
Expand All @@ -34,7 +33,6 @@ limitations under the License.
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -64,9 +62,6 @@ struct CompileModuleResults {
bool use_original_allocations;
};

void ForAllThunks(const std::function<void(Thunk*)>& fn,
ThunkSequence* thunk_sequence);

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path);

Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ cc_library(
xla_cc_test(
name = "triton_test",
srcs = ["triton_test.cc"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
tags = ["gpu"],
deps = [
":fusion_emitter",
":fusions",
Expand Down
13 changes: 5 additions & 8 deletions third_party/xla/xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ limitations under the License.
#include "xla/service/gpu/gpu_constants.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/gpu/runtime/annotation.h"
#include "xla/service/gpu/runtime/for_all_thunks.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
Expand Down Expand Up @@ -99,13 +98,11 @@ using ::tsl::profiler::ScopedAnnotation;
// Returns the set of execution stream ids greater than zero reported by the
// thunks visited while traversing `thunks` with ForAllThunks.
//
// NOTE(review): this span comes from a diff rendering and appears to contain
// both the free-function call form and the member-function call form of the
// same traversal — confirm which one belongs in the final file.
static absl::flat_hash_set<ExecutionStreamId> GetExecutionStreamIds(
const SequentialThunk& thunks) {
absl::flat_hash_set<ExecutionStreamId> stream_ids;
// Free-function form: the callback and a pointer to the thunk sequence are
// passed as arguments.
ForAllThunks(
[&](const Thunk* thunk) {
// Ids that are not > 0 are skipped (presumably 0 denotes the default
// stream — TODO confirm).
if (thunk->execution_stream_id() > 0) {
stream_ids.insert(thunk->execution_stream_id());
}
},
&thunks);
// Member-function form: the traversal is invoked on the sequence itself.
thunks.ForAllThunks([&](const Thunk* thunk) {
// Same filter as above: only ids greater than zero are recorded.
if (thunk->execution_stream_id() > 0) {
stream_ids.insert(thunk->execution_stream_id());
}
});
return stream_ids;
}

Expand Down
27 changes: 7 additions & 20 deletions third_party/xla/xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
"@llvm-project//llvm:Support",
Expand Down Expand Up @@ -473,6 +474,7 @@ cc_library(
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
Expand Down Expand Up @@ -557,6 +559,7 @@ cc_library(
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -1170,6 +1173,7 @@ cc_library(
deps = [
":annotation",
":thunk",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:errors",
Expand Down Expand Up @@ -1216,7 +1220,6 @@ cc_library(
"//xla/core/collectives:rank_id",
"//xla/ffi:execution_context",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:location_exporter",
"//xla/service:buffer_assignment",
"//xla/service:executable",
"//xla/service:global_device_id",
Expand All @@ -1227,44 +1230,27 @@ cc_library(
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/tsl/lib/gtl:int_type",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:statusor",
],
)

# Library target built from for_all_thunks.cc / for_all_thunks.h. It depends on
# the thunk targets listed below (command buffer, conditional, dynamic slice,
# sequential, while), on absl FunctionRef and on tsl casts.
cc_library(
name = "for_all_thunks",
srcs = ["for_all_thunks.cc"],
hdrs = ["for_all_thunks.h"],
deps = [
":command_buffer_thunk",
":conditional_thunk",
":dynamic_slice_thunk",
":sequential_thunk",
":thunk",
":while_thunk",
"@com_google_absl//absl/functional:function_ref",
"@local_tsl//tsl/platform:casts",
],
)

xla_cc_test(
name = "for_all_thunks_test",
srcs = ["for_all_thunks_test.cc"],
tags = ["gpu"],
deps = [
":command_buffer_cmd",
":command_buffer_thunk",
":conditional_thunk",
":dynamic_slice_thunk",
":for_all_thunks",
":sequential_thunk",
":thunk",
":while_thunk",
Expand Down Expand Up @@ -1315,6 +1301,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "xla/service/buffer_assignment.h"
Expand Down Expand Up @@ -340,4 +341,11 @@ void CommandBufferThunk::EvictCommandBuffers() {
}
}

// Invokes `fn` on this thunk first and then, when `thunks_` is set, forwards
// `fn` to `thunks_->ForAllThunks` so the thunks behind it are visited as well.
void CommandBufferThunk::ForAllThunks(
    absl::FunctionRef<void(const Thunk*)> fn) const {
  fn(this);
  if (thunks_ == nullptr) {
    // Nothing beyond `this` to visit.
    return;
  }
  thunks_->ForAllThunks(fn);
}
} // namespace xla::gpu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
Expand Down Expand Up @@ -54,6 +55,8 @@ class CommandBufferThunk : public Thunk {
absl::StatusOr<se::DeviceMemoryBase> GetCommandBufferAllocationAddress(
const ExecuteParams& params, int64_t index);

void ForAllThunks(absl::FunctionRef<void(const Thunk*)> fn) const override;

private:
// Command buffer instantiated on a `se::StreamExecutor` instance, and
// auxiliary state required for efficient command buffer updates.
Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ limitations under the License.
#include <utility>
#include <variant>

#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/variant_visitor.h"
#include "xla/status_macros.h"
Expand Down Expand Up @@ -134,5 +136,14 @@ absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) {
return absl::OkStatus();
}

// Invokes `fn` on this thunk, then forwards `fn` to every entry of
// `config_.branch_thunks`, in container order.
void ConditionalThunk::ForAllThunks(
    absl::FunctionRef<void(const Thunk*)> fn) const {
  fn(this);
  for (const auto& branch : config_.branch_thunks) {
    branch->ForAllThunks(fn);
  }
}

} // namespace gpu
} // namespace xla
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/conditional_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -71,6 +72,8 @@ class ConditionalThunk : public Thunk {
return branch_index_buffer_index_;
}

void ForAllThunks(absl::FunctionRef<void(const Thunk*)> fn) const override;

private:
const ConditionalThunkConfig config_;
const BufferAllocation::Slice branch_index_buffer_index_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -254,5 +255,10 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
return absl::OkStatus();
}

// Invokes `fn` on this thunk, then forwards `fn` to the embedded thunk
// sequence held in `embedded_thunk_`.
void DynamicSliceThunk::ForAllThunks(
    absl::FunctionRef<void(const Thunk*)> fn) const {
  fn(this);
  const SequentialThunk& embedded = *embedded_thunk_;
  embedded.ForAllThunks(fn);
}
} // namespace gpu
} // namespace xla
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "xla/literal.h"
Expand Down Expand Up @@ -107,6 +108,8 @@ class DynamicSliceThunk : public Thunk {
return offset_byte_sizes_;
}

void ForAllThunks(absl::FunctionRef<void(const Thunk*)> fn) const override;

private:
std::unique_ptr<SequentialThunk> embedded_thunk_;
std::vector<std::optional<BufferAllocation::Slice>> arguments_;
Expand Down
Loading

0 comments on commit 68d4e6f

Please sign in to comment.