diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 0df2c215219..664d26e7674 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -29,6 +29,7 @@ cc_library( ":computation_client", ":env_vars", ":pjrt_computation_client", + ":ifrt_computation_client", "@tsl//tsl/platform:stacktrace", ], ) @@ -70,6 +71,36 @@ cc_library( ], ) +cc_library( + name = "ifrt_computation_client", + srcs = [ + "ifrt_computation_client.cc", + ], + hdrs = [ + "ifrt_computation_client.h", + ], + deps = [ + ":computation_client", + ":debug_macros", + ":env_vars", + ":initialize_pjrt", + ":operation_manager", + ":stablehlo_helper", + ":tf_logging", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla/client:xla_computation", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/platform/cloud:gcs_file_system", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "pjrt_computation_client", srcs = [ @@ -83,6 +114,7 @@ cc_library( ":debug_macros", ":env_hash", ":env_vars", + ":initialize_pjrt", ":operation_manager", ":profiler", ":stablehlo_helper", @@ -94,11 +126,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/client:xla_computation", "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - "@xla//xla/service:gpu_plugin", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", @@ -165,6 +193,25 @@ cc_test( ], ) +cc_library( + name = "initialize_pjrt", + srcs = ["initialize_pjrt.cc"], + hdrs = ["initialize_pjrt.h"], + deps = [ + ":debug_macros", + ":env_hash", + ":env_vars", + ":profiler", + ":sys_util", + ":tf_logging", + ":xla_coordinator", + "@xla//xla/service:gpu_plugin", + "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", + "@xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@xla//xla/pjrt:pjrt_c_api_client", + ], +) + cc_library( name = "metrics_analysis", srcs = ["metrics_analysis.cc"], @@ -410,28 +457,52 @@ ptxla_cc_test( ], ) -# disable for now since it is flaky on the upstream test. 
-# ptxla_cc_test( -# name = "pjrt_computation_client_test", -# srcs = ["pjrt_computation_client_test.cc"], -# deps = [ -# ":computation_client", -# ":pjrt_computation_client", -# ":tensor_source", -# "@xla//xla:literal", -# "@xla//xla:literal_util", -# "@xla//xla:shape_util", -# "@xla//xla:status", -# "@xla//xla:statusor", -# "@xla//xla/client:xla_builder", -# "@xla//xla/client:xla_computation", -# "@xla//xla/tests:literal_test_util", -# "@xla//xla/tools:hlo_module_loader", -# "@tsl//tsl/lib/core:status_test_util", -# "@tsl//tsl/platform:env", -# "@tsl//tsl/platform:errors", -# "@tsl//tsl/platform:logging", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_main", -# ], -# ) +ptxla_cc_test( + name = "pjrt_computation_client_test", + srcs = ["pjrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":pjrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +ptxla_cc_test( + name = "ifrt_computation_client_test", + srcs = ["ifrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":ifrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 6e87c13b193..97f633a39f4 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -375,6 +375,8 @@ class ComputationClient { static int64_t GetDeviceOrdinal(const std::string& device); protected: + static constexpr auto spmd_device_str = "SPMD:0"; + // Metrics common to all client interfaces. 
static metrics::Metric* TransferToServerMetric(); static metrics::Metric* TransferToServerTransformMetric(); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc new file mode 100644 index 00000000000..605826e6e6a --- /dev/null +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -0,0 +1,659 @@ +#include "torch_xla/csrc/runtime/ifrt_computation_client.h" + +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/synchronization/blocking_counter.h" +#include "absl/types/span.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/env_hash.h" +#include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" +#include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" +#include "tsl/profiler/lib/traceme.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" + +using xla::internal::XlaBuilderFriend; + +namespace torch_xla { +namespace runtime { + +namespace { + +// Builds a map from the device's global ordinal to its index in the `devices` +// array. +std::unordered_map build_index_map( + const std::vector& devices) { + std::unordered_map device_index; + for (int i = 0; i < devices.size(); ++i) { + std::vector device_spec = absl::StrSplit(devices[i], ':'); + XLA_CHECK_EQ(device_spec.size(), 2) + << "Invalid device specification: " << devices[i]; + int global_ordinal = std::stoi(device_spec[1]); + device_index[global_ordinal] = i; + } + return device_index; +} + +torch::lazy::hash_t hash_comp_env( + xla::ifrt::Client* client, + std::vector& ordered_devices) { + torch::lazy::hash_t hash = hash::HashXlaEnvVars(); + auto topology_desc = client->GetTopologyForDevices(ordered_devices); + if (topology_desc.ok()) { + // Some backends support a topology description which provides a better + // view of the specific compilation environment. + auto serialized = topology_desc.value()->Serialize(); + if (serialized.ok()) { + return torch::lazy::HashCombine( + hash, + torch::lazy::DataHash(serialized->data(), serialized->length())); + } + // If serialization fails, fallthrough to the manual approach. + } + std::string platform_name(client->platform_name()); + std::string platform_version(client->platform_version()); + hash = torch::lazy::HashCombine( + hash, torch::lazy::StringHash(platform_name.c_str())); + // platform_version incorporates libtpu version and hardware type. + hash = torch::lazy::HashCombine( + hash, torch::lazy::StringHash(platform_version.c_str())); + // Include global devices in the hash, ensuring order is consistent. 
+ for (auto& device : ordered_devices) { + std::string device_str(device->ToString()); + hash = torch::lazy::HashCombine( + hash, torch::lazy::StringHash(device_str.c_str())); + } + return hash; +} + +} // namespace + +std::string IfrtComputationClient::IfrtDeviceToString( + xla::ifrt::Device* const device) const { + std::string platform = + absl::AsciiStrToUpper(device->client()->platform_name()); + int ordinal = global_ordinals_.at(device->id()); + std::string str = absl::StrFormat("%s:%d", platform, ordinal); + return str; +} + +std::vector IfrtComputationClient::IfrtDevicesToString( + absl::Span devices) const { + std::vector strs; + strs.reserve(devices.size()); + + for (auto* device : devices) { + strs.push_back(IfrtDeviceToString(device)); + } + + return strs; +} + +IfrtComputationClient::IfrtComputationClient() { + std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); + std::unique_ptr pjrt_client; + std::tie(pjrt_client, coordinator_) = InitializePjRt(device_type); + + client_ = xla::ifrt::PjRtClient::Create(std::move(pjrt_client)); + + // PjRtDevice IDs are not guaranteed to be dense, so we need to track + // a device's global ordinal separately from its device ID. Order the + // devices by increasing ID to assign global ordinals. + std::vector ordered_devices(client_->device_count()); + std::partial_sort_copy(client_->devices().begin(), client_->devices().end(), + ordered_devices.begin(), ordered_devices.end(), + [](auto& a, auto& b) { return a->id() < b->id(); }); + for (auto* device : ordered_devices) { + global_ordinals_[device->id()] = global_ordinals_.size(); + std::string device_str = IfrtDeviceToString(device); + string_to_device_.emplace(device_str, device); + } + comp_env_hash_ = hash_comp_env(client_.get(), ordered_devices); + + auto tracked_devices = GetLocalDevices(); + tracked_devices.emplace_back(spmd_device_str); + operation_manager_ = std::move(OperationManager(std::move(tracked_devices))); +} + +IfrtComputationClient::~IfrtComputationClient() { + // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient + // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. 
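+  // Releasing client_ before coordinator_ below preserves that destruction
+  // order.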
+ client_ = nullptr; + coordinator_ = nullptr; +} + +bool IfrtComputationClient::CoordinatorInitialized() const { + return coordinator_ != nullptr; +} + +void IfrtComputationClient::InitializeCoordinator(int global_rank, + int world_size, + std::string master_addr, + std::string port) { + XLA_CHECK(coordinator_ == nullptr) + << "Can only initialize the XlaCoordinator once."; + coordinator_ = std::make_unique(global_rank, world_size, + master_addr, port); +} + +XlaCoordinator& IfrtComputationClient::GetCoordinator() { + XLA_CHECK(coordinator_ != nullptr) + << "XlaCoordinator has not been initialized"; + return *coordinator_; +} + +void IfrtComputationClient::IfrtData::Assign( + const torch::lazy::BackendData& data) { + const IfrtData& ifrt_data = dynamic_cast(data); + if (&ifrt_data != this) { + buffer = ifrt_data.buffer; + } +} + +xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { + XLA_CHECK(HasSharding()) << "Check HasSharding first"; + return *sharding_; +} + +ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder( + std::string device, xla::Shape shape) { + return std::make_shared(device, shape); +} + +std::vector IfrtComputationClient::GetDataShards( + ComputationClient::DataPtr data) { + tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShards", + tsl::profiler::TraceMeLevel::kInfo); + std::vector shards; + if (data->HasSharding()) { + auto ifrt_data = std::dynamic_pointer_cast(data); + std::vector> arrays = + ifrt_data->buffer + ->DisassembleIntoSingleDeviceArrays( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value(); + + for (auto array : arrays) { + shards.push_back(std::make_shared( + IfrtDeviceToString(array->sharding().devices()[0]), array)); + } + } else { + shards.push_back(data); + } + return shards; +} + +ComputationClient::DataPtr IfrtComputationClient::GetDataShard( + ComputationClient::DataPtr data, size_t index) { + tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShard", + tsl::profiler::TraceMeLevel::kInfo); + return GetDataShards(data)[index]; +} + +ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( + const std::vector& shards, std::string device, xla::Shape shape, + xla::OpSharding sharding) { + // TODO: implement CreateDataPlaceholder for sharded data + if (shards.size() == 0) { + TF_LOG(INFO) << "creating sharded placeholder"; + return std::make_shared( + device, shape, tsl::RCReference(), sharding); + } + std::vector> arrays; + std::vector shard_shapes; + for (auto& shard : shards) { + auto ifrt_shard = std::dynamic_pointer_cast(shard); + arrays.push_back(ifrt_shard->buffer); + shard_shapes.push_back(ifrt_shard->buffer->shape()); + } + xla::ifrt::Shape ifrt_shape(shape.dimensions()); + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), + client_->addressable_devices().end()}); + XLA_CHECK_EQ(shard_shapes.size(), devices_list.size()); + std::unique_ptr ifrt_sharding = + xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(), + ifrt_shape, shard_shapes); + // TODO: Attach HloSharding instead when it is supported + // std::unique_ptr ifrt_sharding = + // xla::ifrt::HloSharding::Create( + // devices_list, + // xla::ifrt::MemoryKind(), + // xla::HloSharding::FromProto(sharding).value() + // ); + tsl::RCReference sharded_array = + client_ + ->AssembleArrayFromSingleDeviceArrays( + ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value(); + return std::make_shared(device, shape, 
sharded_array, sharding); +} + +std::optional IfrtComputationClient::GetDataSharding( + DataPtr handle) { + auto ifrt_data = std::dynamic_pointer_cast(handle); + return ifrt_data->sharding_; +} + +std::vector IfrtComputationClient::TransferToServer( + absl::Span> tensors) { + auto timed = + std::make_shared(TransferToServerMetric()); + tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", + tsl::profiler::TraceMeLevel::kInfo); + std::vector datas; + datas.reserve(tensors.size()); + int64_t total_size = 0; + for (auto& tensor : tensors) { + xla::ifrt::Device* ifrt_device = StringToIfrtDevice(tensor->device()); + + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); + + tsl::RCReference buffer = + client_ + ->MakeArrayFromHostBuffer( + tensor->data(), + xla::ifrt::ToDType(tensor->primitive_type()).value(), + xla::ifrt::Shape(tensor->dimensions()), tensor->byte_strides(), + // TODO: what is MemoryKind? + xla::ifrt::SingleDeviceSharding::Create( + ifrt_device, xla::ifrt::MemoryKind()), + xla::ifrt::Client::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [tensor, timed]() { /* frees tensor and timer */ }) + .value(); + + ComputationClient::DataPtr data = + std::make_shared(tensor->device(), tensor->shape(), buffer); + datas.push_back(data); + } + OutboundDataMetric()->AddSample(total_size); + CreateDataHandlesCounter()->AddValue(datas.size()); + + return datas; +} + +ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { + tsl::profiler::TraceMe activity( + "IfrtComputationClient::TransferShardsToServer", + tsl::profiler::TraceMeLevel::kInfo); + // TODO(jonbolin): Consider using CopyToDevice when sharding is REPLICATED. + // We are opting out of CopyToDevice for now due to the synchronization + // issues observed in ShardingUtil::InputHandler, but because CopyToDevice + // directly copies buffers between devices using ICI, it can be much faster + // than transferring from the host to each device. 
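+  // For now, copy each shard to its device with a host transfer, then
+  // assemble the per-device arrays into a single sharded ifrt::Array below.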
+ auto data_shards = TransferToServer(tensor_shards); + std::vector> arrays; + std::vector shard_shapes; + for (auto& shard : data_shards) { + auto ifrt_shard = std::dynamic_pointer_cast(shard); + arrays.push_back(ifrt_shard->buffer); + shard_shapes.push_back(ifrt_shard->buffer->shape()); + } + xla::ifrt::Shape ifrt_shape(shape.dimensions()); + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), + client_->addressable_devices().end()}); + std::unique_ptr ifrt_sharding = + xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(), + ifrt_shape, shard_shapes); + // TODO: Attach HloSharding instead when it is supported + // std::unique_ptr ifrt_sharding = + // xla::ifrt::HloSharding::Create( + // devices_list, + // xla::ifrt::MemoryKind(), + // xla::HloSharding::FromProto(sharding).value() + // ); + tsl::RCReference sharded_array = + client_ + ->AssembleArrayFromSingleDeviceArrays( + ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value(); + return std::make_shared(device, shape, sharded_array, sharding); +} + +ComputationClient::DataPtr IfrtComputationClient::CopyToDevice( + ComputationClient::DataPtr data, std::string dst) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + +tsl::RCReference IfrtComputationClient::ReplicateShardedData( + const std::shared_ptr handle) { + if (handle->buffer->sharding().devices().size() == 1) { + return handle->buffer; + } + + XLA_COUNTER("ReplicateShardedData", 1); + TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() + << ", shape=" << handle->shape() << ")"; + // TODO: handle replicated data + xla::XlaBuilder builder("ReplicateShardedData"); + xla::Shape shape = handle->shape(); + builder.SetSharding(handle->GetSharding()); + + // perform a simple identity calculation to reassemble the input as + // replicated output. 
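+  // The parameter keeps the input's sharding; adding a scalar zero under a
+  // replicated output sharding lets the SPMD partitioner insert the
+  // collectives that gather the shards into a single replicated result.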
+ xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0"); + builder.SetSharding(xla::HloSharding::Replicate().ToProto()); + xla::XlaOp scalar_zero_op = xla::ConvertElementType( + xla::ConstantR0(&builder, 0), shape.element_type()); + xla::XlaOp y = xla::Add(x, scalar_zero_op); + auto instruction = XlaBuilderFriend::GetInstruction(y); + *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto(); + + xla::XlaComputation computation = + ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false)); + xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); + + std::string device = GetDefaultDevice(); + std::vector instances; + instances.push_back({std::move(computation), device, + GetCompilationDevices(device, {}), &shape, + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/false}); + std::vector< + std::shared_ptr> + computations = Compile(std::move(instances)); + + XLA_CHECK_EQ(handle->buffer->sharding().devices().size(), + GetLocalDevices().size()); + + torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions + execute_options; + + auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}}, + GetLocalDevices(), execute_options); + auto replicated_output = + std::dynamic_pointer_cast(sharded_results[0]) + ->buffer->FullyReplicatedShard( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy); + // TODO: sanity check outputs + return *replicated_output; +} + +std::vector IfrtComputationClient::TransferFromServer( + absl::Span handles) { + metrics::TimedSection timed(TransferFromServerMetric()); + tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromServer", + tsl::profiler::TraceMeLevel::kInfo); + std::vector literals; + literals.reserve(handles.size()); + int64_t total_size = 0; + for (auto handle : handles) { + // Use XLA replication to reassemble the sharded data. If input handle + // is not sharded, then it is a no-op. + auto ifrt_data = std::dynamic_pointer_cast(handle); + tsl::RCReference replicated_array = + ReplicateShardedData(ifrt_data); + + // TODO: handle dynamic shapes + auto& literal = literals.emplace_back( + xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape())); + std::vector byte_strides(literal.shape().dimensions_size()); + XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), + absl::MakeSpan(byte_strides))); + XLA_CHECK_OK( + replicated_array + ->CopyToHostBuffer(literal.untyped_data(), byte_strides, + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .Await()); + + total_size += literal.size_bytes(); + } + InboundDataMetric()->AddSample(total_size); + + return literals; +} + +std::vector IfrtComputationClient::Compile( + std::vector instances) { + metrics::TimedSection timed(CompileMetric()); + tsl::profiler::TraceMe activity("IfrtComputationClient::Compile", + tsl::profiler::TraceMeLevel::kInfo); + std::vector computations; + + for (auto& instance : instances) { + xla::CompileOptions compile_options; + if (instance.is_sharded) { + // TODO(yeounoh) multi-host, multi-slice configurations + compile_options.executable_build_options.set_use_spmd_partitioning(true); + // We can override the compiler's default behavior to replicate the + // outputs. 
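+      // When propagation is allowed, the partitioner's chosen output
+      // shardings stand instead of the outputs being forced to replicated.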
+ compile_options.executable_build_options + .set_allow_spmd_sharding_propagation_to_output( + {instance.allow_spmd_sharding_propagation_to_output}); + compile_options.executable_build_options.set_num_partitions( + client_->device_count()); + compile_options.executable_build_options.set_num_replicas(1); + compile_options.parameter_is_tupled_arguments = + instance.parameter_is_tupled_arguments; + + // TODO(244391366) verify this is correct for the collectives ops + xla::DeviceAssignment device_assignment(1, client_->device_count()); + // DeviceAssignment values must be the PjRtDevice ID, so we need to + // unwind the global ordinal mapping. + for (const auto& [device_id, global_ordinal] : global_ordinals_) { + device_assignment(0, global_ordinal) = device_id; + } + compile_options.executable_build_options.set_device_assignment( + device_assignment); + } else { + XLA_ERROR() << "Only SPMD compilation is supported"; + } + + // Convert HLO to StableHLO for Ifrt client compilation. + mlir::MLIRContext context; + mlir::ModuleOp mlir_module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + torch_xla::ConvertHloToStableHlo(instance.computation.mutable_proto(), + &mlir_module); + std::unique_ptr executable = + ConsumeValue(client_->GetDefaultCompiler()->Compile( + std::make_unique(std::move(mlir_module)), + std::make_unique(compile_options))); + StableHloCompileCounter()->AddValue(1); + + const auto& hlo_modules = ConsumeValue(executable->GetHloModules()); + + std::shared_ptr ifrt_computation = + std::make_shared( + std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), + instance.devices, std::move(executable)); + + computations.push_back(ifrt_computation); + + CreateCompileHandlesCounter()->AddValue(1); + } + + return computations; +} + +std::vector +IfrtComputationClient::ExecuteComputation( + const ComputationClient::Computation& computation, + absl::Span arguments, + const std::string& device, const ExecuteComputationOptions& options) { + // TODO: Implement sharded exec in IFRT + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + +std::vector +IfrtComputationClient::ExecuteReplicated( + const ComputationClient::Computation& computation, + const absl::Span arguments, + // TODO: devices isn't doing anything helpful here + absl::Span devices, + const ExecuteReplicatedOptions& options) { + // Shared ownership of the timed section ensures that it will only get logged + // once both `ExecuteReplicated` and the async work in `Execute` are + // complete; a copy is held from the lambda that releases it when done. + auto timed = + std::make_shared(ExecuteReplicatedMetric()); + tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated", + tsl::profiler::TraceMeLevel::kInfo); + const IfrtComputation& ifrt_computation = + dynamic_cast(computation); + + std::vector> argument_handles( + arguments.size()); + { + absl::BlockingCounter counter(arguments.size()); + + // Cost to handle one input argument. 
See tsl::ThreadPool::ParallelFor + // documentation + static const int32_t argument_handle_cost_ns = 1000; + pool_.ParallelFor(arguments.size(), argument_handle_cost_ns, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto ifrt_data = + std::dynamic_pointer_cast(arguments[i]); + argument_handles[i] = ifrt_data->buffer; + counter.DecrementCount(); + } + }); + counter.Wait(); + } + + xla::ExecuteOptions execute_options; + execute_options.untuple_result = options.explode_tuple; + execute_options.strict_shape_checking = true; + // TODO(yeounoh) currently only support single-slice execution + execute_options.multi_slice_config = nullptr; + + TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for " + << spmd_device_str; + auto op_tracker = operation_manager_.StartOperation(spmd_device_str); + TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for " + << spmd_device_str << " Done"; + + xla::ifrt::LoadedExecutable::ExecuteResult result = + ifrt_computation.executable + ->Execute(absl::MakeSpan(argument_handles), execute_options, + std::nullopt) + .value(); + + result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)]( + xla::Status status) mutable { + timed.reset(); + TF_VLOG(3) + << "ExecuteReplicated returned_future->OnReady finished with status " + << status; + })); + + auto outputs = result.outputs; + + const std::vector& output_shardings = + ifrt_computation.output_shardings_ + ? *ifrt_computation.output_shardings_ + : std::vector(outputs.size(), + xla::HloSharding::Replicate().ToProto()); + XLA_CHECK_EQ(output_shardings.size(), outputs.size()); + + std::vector data_handles(outputs.size()); + { + absl::BlockingCounter counter(outputs.size()); + + // Cost to handle one output. See tsl::ThreadPool::ParallelFor + // documentation. 
+ static const int32_t result_handle_cost_ns = 2000; + pool_.ParallelFor(outputs.size(), result_handle_cost_ns, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + data_handles[i] = std::make_shared( + spmd_device_str, outputs[i], output_shardings[i]); + counter.DecrementCount(); + } + }); + counter.Wait(); + } + + TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; + return data_handles; +} + +size_t IfrtComputationClient::GetNumDevices() const { + return client_->addressable_device_count(); +} + +std::string IfrtComputationClient::GetDefaultDevice() const { + return IfrtDeviceToString(client_->addressable_devices()[0]); +} + +std::vector IfrtComputationClient::GetLocalDevices() const { + return IfrtDevicesToString(client_->addressable_devices()); +} + +std::vector IfrtComputationClient::GetAllDevices() const { + return IfrtDevicesToString(client_->devices()); +} + +int IfrtComputationClient::GetNumProcesses() const { + int max_process_index = client_->process_index(); + for (auto* device : client_->devices()) { + max_process_index = std::max(max_process_index, device->process_index()); + } + + return max_process_index + 1; +}; + +const absl::flat_hash_map< + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& +IfrtComputationClient::GetDeviceAttributes(const std::string& device) { + return IfrtComputationClient::StringToIfrtDevice(device)->Attributes(); +} + +void IfrtComputationClient::SetReplicationDevices( + std::shared_ptr> devices) { + replication_devices_ = std::move(devices); +} + +std::shared_ptr> +IfrtComputationClient::GetReplicationDevices() { + return replication_devices_; +} + +xla::ifrt::Device* IfrtComputationClient::StringToIfrtDevice( + const std::string& device) { + XLA_CHECK(string_to_device_.find(device) != string_to_device_.end()) + << "Unknown device " << device; + xla::ifrt::Device* ifrt_device = string_to_device_[device]; + return ifrt_device; +} + +void IfrtComputationClient::WaitDeviceOps( + absl::Span devices) { + TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", "); + operation_manager_.WaitForDevices(devices.empty() ? 
GetLocalDevices() + : devices); +} + +std::map IfrtComputationClient::GetMetrics() const { + // TODO(jonbolin): Add any Ifrt-client-specific metrics here + return {}; +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h new file mode 100644 index 00000000000..1c7e44bd6ee --- /dev/null +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -0,0 +1,219 @@ +#ifndef XLA_CLIENT_IFRT_COMPUTATION_CLIENT_H_ +#define XLA_CLIENT_IFRT_COMPUTATION_CLIENT_H_ + +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/operation_manager.h" +#include "torch_xla/csrc/runtime/util.h" +#include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/shape.h" + +namespace torch_xla { +namespace runtime { + +class IfrtComputationClient : public ComputationClient { + public: + IfrtComputationClient(); + ~IfrtComputationClient(); + + DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; + + std::vector GetDataShards(DataPtr data) override; + + DataPtr GetDataShard(DataPtr data, size_t index) override; + + DataPtr WrapDataShards(const std::vector& shards, std::string device, + xla::Shape shape, xla::OpSharding sharding) override; + + std::optional GetDataSharding(DataPtr handle) override; + + std::vector TransferToServer( + absl::Span> tensors) override; + + std::vector TransferFromServer( + absl::Span handles) override; + + DataPtr TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override; + + DataPtr CopyToDevice(DataPtr data, std::string dst) override; + + std::vector Compile( + std::vector instances) override; + + std::vector ExecuteComputation( + const Computation& computation, absl::Span arguments, + const std::string& device, + const ExecuteComputationOptions& options) override; + + std::vector ExecuteReplicated( + const Computation& computation, const absl::Span arguments, + absl::Span devices, + const ExecuteReplicatedOptions& options) override; + + size_t GetNumDevices() const override; + + std::string GetDefaultDevice() const override; + + std::vector GetLocalDevices() const override; + + std::vector GetAllDevices() const override; + + int GetProcessIndex() const override { return client_->process_index(); }; + + int GetNumProcesses() const override; + + const absl::flat_hash_map< + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + GetDeviceAttributes(const std::string& device) override; + + void SetReplicationDevices( + std::shared_ptr> devices) override; + + std::shared_ptr> GetReplicationDevices() override; + + void WaitDeviceOps(absl::Span devices) override; + + std::map GetMetrics() const override; + + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override; + + XlaCoordinator& GetCoordinator() override; + + bool CoordinatorInitialized() const override; + + torch::lazy::hash_t HashCompilationEnv() override { return comp_env_hash_; } + + // NOT IMPLEMENTED + + MemoryInfo GetMemoryInfo(const std::string& device) override { + 
XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; + + std::string SerializeComputation(const ComputationPtr computation) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + ComputationPtr DeserializeComputation( + const std::string& serialized) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + private: + std::shared_ptr client_; + std::unique_ptr coordinator_; + // global_ordinals_ tracks a map from PjRtDeviceId to the device's + // dense global ordinal. + std::unordered_map global_ordinals_; + std::unordered_map string_to_device_; + std::shared_ptr> replication_devices_; + OperationManager operation_manager_; + tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( + tsl::Env::Default(), "ifrt", std::thread::hardware_concurrency()); + torch::lazy::hash_t comp_env_hash_; + + xla::ifrt::Device* StringToIfrtDevice(const std::string& device); + + std::string IfrtDeviceToString(xla::ifrt::Device* const device) const; + std::vector IfrtDevicesToString( + absl::Span devices) const; + + struct IfrtData : public Data { + IfrtData(std::string device, xla::Shape device_shape) + : Data(std::move(device), std::move(device_shape)) {} + + IfrtData(std::string device, xla::Shape device_shape, + tsl::RCReference buffer, + std::optional sharding = std::nullopt) + : Data(std::move(device), std::move(device_shape)), + buffer(buffer), + sharding_(sharding) {} + + IfrtData(std::string device, tsl::RCReference buffer, + std::optional sharding = std::nullopt) + : Data(std::move(device), + xla::ShapeUtil::MakeShape( + xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), + buffer->shape().dims())), + buffer(buffer), + sharding_(sharding) {} + + Handle GetHandle() override { + XLA_CHECK(HasValue()) + << "buffer with shape " << shape().ToString() << " on device " + << device() << (buffer == nullptr ? 
" is null" : " is deleted"); + return reinterpret_cast(buffer.get()); + }; + void Assign(const torch::lazy::BackendData& data) override; + bool HasValue() const override { + return buffer != nullptr; // TODO: && !buffer->IsDeleted(); + }; + + bool HasSharding() const override { return sharding_.has_value(); } + + xla::OpSharding GetSharding() const override; + + std::string ToString() const override { + std::stringstream ss; + + if (HasSharding()) { + ss << "XLAShardedData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " OpSharding: " + << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n"; + ss << " NumShards: " << buffer->sharding().devices().size() << "\n"; + } else { + ss << "XLAData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " Data Handle: "; + if (HasValue()) { + ss << reinterpret_cast(buffer.get()) << "\n"; + } else { + ss << "None\n"; + } + } + return ss.str(); + } + + std::optional sharding_; + tsl::RCReference buffer; + }; + + tsl::RCReference ReplicateShardedData( + const std::shared_ptr handle); + + struct IfrtComputation : public Computation { + IfrtComputation(xla::XlaComputation computation, + std::vector devices, + std::unique_ptr executable) + : Computation(std::move(computation), std::move(devices)), + executable(std::move(executable)) { + output_shardings_ = this->executable->GetOutputShardings(); + } + + std::unique_ptr executable; + std::optional> output_shardings_; + }; +}; + +} // namespace runtime +} // namespace torch_xla +#endif // XLA_CLIENT_IFRT_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cc b/torch_xla/csrc/runtime/ifrt_computation_client_test.cc new file mode 100644 index 00000000000..21194432dd4 --- /dev/null +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cc @@ -0,0 +1,81 @@ +#include "torch_xla/csrc/runtime/ifrt_computation_client.h" + +#include + +#include +#include +#include + +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/tensor_source.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/tests/literal_test_util.h" + +namespace torch_xla { +namespace runtime { + +tsl::StatusOr MakeComputation() { + xla::Shape input_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); + xla::XlaBuilder builder("AddComputation"); + xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x"); + xla::XlaOp y = xla::Parameter(&builder, 1, input_shape, "y"); + xla::XlaOp sum = xla::Add(x, y); + return builder.Build(); +} + +TEST(PjRtComputationClientTest, Init) { + // Get a CPU client. + tsl::setenv("PJRT_DEVICE", "CPU", true); + auto client = std::make_unique(); + std::string device = client->GetDefaultDevice(); + + // Compose a computation. + auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2}); + std::vector instances; + instances.push_back(ComputationClient::CompileInstance( + std::move(MakeComputation().value()), device, + client->GetCompilationDevices(device, client->GetLocalDevices()), &shape, + /*parameter_is_tupled_arguments=*/false, /*is_sharded=*/true)); + + // Prepare inputs. 
+  xla::Literal literal_x =
+      xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
+  xla::Literal literal_y =
+      xla::LiteralUtil::CreateR2<float>({{5.0f, 6.0f}, {7.0f, 8.0f}});
+
+  // Compile the graph.
+  std::vector<ComputationClient::ComputationPtr> computations =
+      client->Compile(std::move(instances));
+
+  // Copy inputs to device.
+  ComputationClient::ExecuteReplicatedOptions options{};
+  std::vector<std::shared_ptr<const TensorSource>> args = {
+      std::make_shared<LiteralSource>(std::move(literal_x), device),
+      std::make_shared<LiteralSource>(std::move(literal_y), device)};
+
+  // Execute the graph.
+  std::vector<ComputationClient::DataPtr> results = client->ExecuteReplicated(
+      *computations[0], client->TransferToServer(absl::MakeConstSpan(args)),
+      {device}, options);
+
+  // Copy the output from device back to host and assert correctness.
+  ASSERT_EQ(results.size(), 1);
+  auto result_literals = client->TransferFromServer(results);
+  ASSERT_THAT(result_literals, ::testing::SizeIs(1));
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(
+      xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
+      result_literals[0]));
+}
+
+}  // namespace runtime
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc
new file mode 100644
index 00000000000..4e5f6ba7a1d
--- /dev/null
+++ b/torch_xla/csrc/runtime/initialize_pjrt.cc
@@ -0,0 +1,135 @@
+#include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/env_vars.h"
+#include "torch_xla/csrc/runtime/profiler.h"
+#include "torch_xla/csrc/runtime/sys_util.h"
+#include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
+#include "xla/pjrt/c/pjrt_c_api.h"
+#include "xla/pjrt/distributed/distributed.h"
+#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
+#include "xla/pjrt/pjrt_api.h"
+#include "xla/pjrt/pjrt_c_api_client.h"
+#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
+
+namespace torch_xla {
+namespace runtime {
+
+namespace {
+
+xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
+  auto allocator_config = xla::GpuAllocatorConfig{};
+  if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() &&
+      sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() &&
+      sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) {
+    return allocator_config;
+  }
+  if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) {
+    allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync;
+  }
+  allocator_config.preallocate =
+      sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true);
+  allocator_config.memory_fraction =
+      sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75);
+  return allocator_config;
+}
+
+}  // namespace
+
+std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
+InitializePjRt(const std::string& device_type) {
+  std::unique_ptr<xla::PjRtClient> client;
+  std::unique_ptr<XlaCoordinator> coordinator;
+
+  if (device_type == "CPU") {
+    TF_VLOG(1) << "Initializing PjRt CPU client...";
+    bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
+    int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
+    client = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value());
+  } else if (device_type == "TPU" || device_type == "TPU_C_API") {
+    TF_VLOG(1) << "Initializing TFRT TPU client...";
+    // Prefer $TPU_LIBRARY_PATH if set
+    auto tpu_library_path = sys_util::GetEnvString(
+        env::kEnvTpuLibraryPath,
+        sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
+    XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
+    tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
+    XLA_CHECK_OK(tpu_status);
+    client =
std::move(xla::GetCApiClient("TPU").value()); + const PJRT_Api* c_api = + static_cast(client.get())->pjrt_c_api(); + profiler::RegisterProfilerForPlugin(c_api); + } else if (device_type == "TPU_LEGACY") { + XLA_ERROR() << "TPU_LEGACY client is no longer available."; + } else if (device_type == "GPU" || device_type == "CUDA" || + device_type == "ROCM") { + TF_VLOG(1) << "Initializing PjRt GPU client..."; + bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); + int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); + int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); + int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); + int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); + std::string master_addr = + runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); + std::string port = runtime::sys_util::GetEnvString( + "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); + + xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; + xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; + auto allowed_devices = + std::make_optional>(std::set{local_process_rank}); + if (global_world_size > 1) { + // Use the XlaCoordinator as the distributed key-value store. + coordinator = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator->GetClient(); + std::string key_prefix = "gpu:"; + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + }; + } + TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" + << global_process_rank << ", num_nodes=" << global_world_size; + xla::GpuClientOptions options; + options.allocator_config = GetGpuAllocatorConfig(); + options.node_id = global_process_rank; + options.num_nodes = global_world_size; + options.allowed_devices = allowed_devices; + options.platform_name = "gpu"; + options.should_stage_host_to_device_transfers = true; + options.kv_get = kv_get; + options.kv_put = kv_put; + client = std::move(xla::GetStreamExecutorGpuClient(options).value()); + } else if (device_type == "XPU") { + TF_VLOG(1) << "Initializing PjRt XPU client..."; + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) + .status()); + client = std::move(xla::GetCApiClient("XPU").value()); + } else if (device_type == "NEURON") { + TF_VLOG(1) << "Initializing PjRt NEURON client..."; + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( + env::kEnvNeuronLibraryPath, + "libneuronpjrt.so")) + .status()); + client = std::move(xla::GetCApiClient("NEURON").value()); + } else { + XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, + device_type); + } + + XLA_CHECK(client.get() != nullptr); + + return {std::move(client), std::move(coordinator)}; +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/initialize_pjrt.h new file mode 100644 index 00000000000..012927fe474 --- /dev/null +++ b/torch_xla/csrc/runtime/initialize_pjrt.h @@ -0,0 +1,15 @@ +#ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ +#define 
XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ + +#include "xla/pjrt/pjrt_client.h" + +namespace torch_xla { +namespace runtime { + +std::tuple, std::unique_ptr> +InitializePjRt(const std::string& device_type); + +} +} // namespace torch_xla + +#endif // XLA_CLIENT_INITIALIZE_PJRT_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 3796b3e41ab..1ca27518282 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -13,6 +13,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_hash.h" #include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" @@ -25,14 +26,8 @@ #include "xla/client/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/distributed/distributed.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape.h" using xla::internal::XlaBuilderFriend; @@ -42,8 +37,6 @@ namespace runtime { namespace { -static std::string spmd_device_str = "SPMD:0"; - // Builds a map from the device's global ordinal to its index in the `devices` // array. std::unordered_map build_index_map( @@ -68,26 +61,8 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { return xla::ShapeUtil::DeviceShapeToHostShape(shape); } -xla::GpuAllocatorConfig GetGpuAllocatorConfig() { - auto allocator_config = xla::GpuAllocatorConfig{}; - if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { - return allocator_config; - } - if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { - allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; - } - allocator_config.preallocate = - sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); - allocator_config.memory_fraction = - sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); - return allocator_config; -} - torch::lazy::hash_t hash_comp_env( - std::shared_ptr client, - std::vector& ordered_devices) { + xla::PjRtClient* client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); auto topology_desc = client->GetTopologyDescription(); if (topology_desc.ok()) { @@ -142,93 +117,7 @@ std::vector PjRtComputationClient::PjRtDevicesToString( PjRtComputationClient::PjRtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - if (device_type == "CPU") { - TF_VLOG(1) << "Initializing PjRt CPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); - int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); - } else if (device_type == "TPU" || device_type == "TPU_C_API") { - TF_VLOG(1) << "Initializing TFRT TPU client..."; - // Prefer $TPU_LIBRARY_PATH if set - auto tpu_library_path = sys_util::GetEnvString( - env::kEnvTpuLibraryPath, - 
sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); - tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); - XLA_CHECK_OK(tpu_status); - client_ = std::move(xla::GetCApiClient("TPU").value()); - const PJRT_Api* c_api = - static_cast(client_.get())->pjrt_c_api(); - profiler::RegisterProfilerForPlugin(c_api); - } else if (device_type == "TPU_LEGACY") { - XLA_ERROR() << "TPU_LEGACY client is no longer available."; - } else if (device_type == "GPU" || device_type == "CUDA" || - device_type == "ROCM") { - TF_VLOG(1) << "Initializing PjRt GPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); - int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); - int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); - int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); - std::string master_addr = - runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); - std::string port = runtime::sys_util::GetEnvString( - "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - - xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; - xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; - auto allowed_devices = - std::make_optional>(std::set{local_process_rank}); - if (global_world_size > 1) { - // Use the XlaCoordinator as the distributed key-value store. - coordinator_ = std::make_unique( - global_process_rank, global_world_size, master_addr, port); - std::shared_ptr distributed_client = - coordinator_->GetClient(); - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; - } - TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" - << global_process_rank << ", num_nodes=" << global_world_size; - xla::GpuClientOptions options; - options.allocator_config = GetGpuAllocatorConfig(); - options.node_id = global_process_rank; - options.num_nodes = global_world_size; - options.allowed_devices = allowed_devices; - options.platform_name = "gpu"; - options.should_stage_host_to_device_transfers = true; - options.kv_get = kv_get; - options.kv_put = kv_put; - client_ = std::move(xla::GetStreamExecutorGpuClient(options).value()); - } else if (device_type == "XPU") { - TF_VLOG(1) << "Initializing PjRt XPU client..."; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin( - "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) - .status()); - client_ = std::move(xla::GetCApiClient("XPU").value()); - } else if (device_type == "NEURON") { - TF_VLOG(1) << "Initializing PjRt NEURON client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( - env::kEnvNeuronLibraryPath, - "libneuronpjrt.so")) - .status()); - client_ = std::move(xla::GetCApiClient("NEURON").value()); - } else { - XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, - device_type); - } - - XLA_CHECK(client_.get() != nullptr); + std::tie(client_, coordinator_) = InitializePjRt(device_type); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal 
separately from its device ID. Order the @@ -242,7 +131,7 @@ PjRtComputationClient::PjRtComputationClient() { std::string device_str = PjRtDeviceToString(device); string_to_device_.emplace(device_str, device); } - comp_env_hash_ = hash_comp_env(client_, ordered_devices); + comp_env_hash_ = hash_comp_env(client_.get(), ordered_devices); auto tracked_devices = GetLocalDevices(); tracked_devices.emplace_back(spmd_device_str); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index a54b3c0b3b1..b43ffdecb45 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -110,7 +110,7 @@ class PjRtComputationClient : public ComputationClient { }; private: - std::shared_ptr client_; + std::unique_ptr client_; std::unique_ptr coordinator_; // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index e2f69c44e47..feb2a0844c6 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -3,6 +3,7 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "tsl/platform/stacktrace_handler.h" @@ -19,8 +20,13 @@ ComputationClient* GetComputationClient() { std::unique_ptr client; + static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { - client = std::make_unique(); + if (use_ifrt) { + client = std::make_unique(); + } else { + client = std::make_unique(); + } } else { XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; }