From fdd13af626655f38bf22b8316ab1d50cf481a346 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 5 Oct 2023 17:39:34 +0000 Subject: [PATCH 01/33] Start IFRT prototype --- torch_xla/csrc/runtime/BUILD | 37 + .../csrc/runtime/ifrt_computation_client.cc | 796 ++++++++++++++++++ .../csrc/runtime/ifrt_computation_client.h | 241 ++++++ torch_xla/csrc/runtime/runtime.cc | 9 +- 4 files changed, 1082 insertions(+), 1 deletion(-) create mode 100644 torch_xla/csrc/runtime/ifrt_computation_client.cc create mode 100644 torch_xla/csrc/runtime/ifrt_computation_client.h diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 0df2c215219..e38baa03a93 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -29,6 +29,7 @@ cc_library( ":computation_client", ":env_vars", ":pjrt_computation_client", + ":ifrt_computation_client", "@tsl//tsl/platform:stacktrace", ], ) @@ -70,6 +71,42 @@ cc_library( ], ) +cc_library( + name = "ifrt_computation_client", + srcs = [ + "ifrt_computation_client.cc", + ], + hdrs = [ + "ifrt_computation_client.h", + ], + deps = [ + ":computation_client", + # TODO: why do I need this? 
+        # ":pjrt_computation_client",
+        ":debug_macros",
+        ":env_vars",
+        ":multi_wait",
+        ":stablehlo_helper",
+        ":tf_logging",
+        ":thread_pool",
+        "@xla//xla:literal",
+        "@xla//xla:shape_util",
+        "@xla//xla/client:xla_computation",
+        "@xla//xla/pjrt/distributed",
+        "@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
+        "@xla//xla/service:gpu_plugin",
+        "@xla//xla/pjrt:pjrt_client",
+        "@xla//xla/pjrt:tfrt_cpu_pjrt_client",
+        "@xla//xla/pjrt:pjrt_c_api_client",
+        "@xla//xla/python/ifrt",
+        "@xla//xla/python/pjrt_ifrt",
+        "@tsl//tsl/profiler/lib:traceme",
+        "@tsl//tsl/platform/cloud:gcs_file_system",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "pjrt_computation_client",
     srcs = [
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
new file mode 100644
index 00000000000..60ff8056c68
--- /dev/null
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -0,0 +1,796 @@
+#include "torch_xla/csrc/runtime/ifrt_computation_client.h"
+
+#include
+#include
+#include
+
+#include "absl/strings/ascii.h"
+#include "absl/types/span.h"
+#include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/env_vars.h"
+#include "torch_xla/csrc/runtime/multi_wait.h"
+#include "torch_xla/csrc/runtime/stablehlo_helper.h"
+#include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/runtime/thread_pool.h"
+#include "tsl/profiler/lib/traceme.h"
+#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/layout_util.h"
+#include "xla/literal.h"
+#include "xla/pjrt/distributed/distributed.h"
+#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
+#include "xla/pjrt/pjrt_api.h"
+#include "xla/pjrt/pjrt_c_api_client.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
+#include "xla/python/ifrt/compiler.h"
+#include "xla/python/ifrt/memory.h"
+#include "xla/python/ifrt/sharding.h"
+#include "xla/python/pjrt_ifrt/pjrt_array.h"
+#include "xla/python/pjrt_ifrt/pjrt_client.h"
+#include "xla/shape.h"
+
+using xla::internal::XlaBuilderFriend;
+
+namespace torch_xla {
+namespace runtime {
+
+namespace {
+
+// Key under which the lock for the SPMD device is registered in the
+// constructor.
+static std::string spmd_device_str = "SPMD:0";
+
+// Initializes a distributed runtime client if dist_service_addr is specified.
+// Returns a null pointer when `dist_service_addr` is empty; otherwise connects
+// the client (using `local_rank` as the node id) and fails an XLA_CHECK if the
+// connection cannot be established.
+std::shared_ptr
+MaybeInitializeDistributedRuntimeClient(int local_rank,
+                                        std::string dist_service_addr) {
+  std::shared_ptr client;
+  if (!dist_service_addr.empty()) {
+    xla::DistributedRuntimeClient::Options options;
+    /* TODO(jonbolin): Use global rank for multi-host setup */
+    options.node_id = local_rank;
+    client = xla::GetDistributedRuntimeClient(dist_service_addr, options);
+    XLA_CHECK(client->Connect().ok())
+        << "Failed to initialize distributed runtime client";
+  }
+  // Return the local by name: wrapping it in std::move would disable copy
+  // elision without saving anything.
+  return client;
+}
+
+// Builds a map from the device's global ordinal to its index in the `devices`
+// array. Each entry of `devices` must have the form "<name>:<ordinal>";
+// anything else fails the XLA_CHECK_EQ below.
+std::unordered_map build_index_map(
+    const std::vector& devices) {
+  std::unordered_map device_index;
+  // Convert the size once so the loop does not mix signed and unsigned
+  // operands in its comparison.
+  const int num_devices = static_cast<int>(devices.size());
+  for (int i = 0; i < num_devices; ++i) {
+    std::vector device_spec = absl::StrSplit(devices[i], ':');
+    XLA_CHECK_EQ(device_spec.size(), 2)
+        << "Invalid device specification: " << devices[i];
+    int global_ordinal = std::stoi(device_spec[1]);
+    device_index[global_ordinal] = i;
+  }
+  return device_index;
+}
+
+// Builds the xla::Shape of the output xla::Literal on the host.
+// xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
+//   xla::Shape shape = xla::ShapeUtil::MakeShape(
+//       buffer->element_type(), buffer->logical_dimensions().value());
+//   *shape.mutable_layout() = buffer->layout();
+
+//   return xla::ShapeUtil::DeviceShapeToHostShape(shape);
+// }
+
+}  // namespace
+
+// Formats `device` as "<PLATFORM>:<ordinal>", where PLATFORM is the
+// upper-cased platform name of the device's client and the ordinal is looked
+// up in `global_ordinals_` by device id.
+std::string IfrtComputationClient::PjRtDeviceToString(
+    xla::PjRtDevice* const device) const {
+  std::string platform =
+      absl::AsciiStrToUpper(device->client()->platform_name());
+  int ordinal = global_ordinals_.at(device->id());
+  std::string str = absl::StrFormat("%s:%d", platform, ordinal);
+  return str;
+}
+
+// Applies PjRtDeviceToString to every element of `devices`, preserving order.
+std::vector IfrtComputationClient::PjRtDevicesToString(
+    absl::Span devices) const {
+  std::vector strs;
+  strs.reserve(devices.size());
+
+  for (auto* device : devices) {
+    strs.push_back(PjRtDeviceToString(device));
+  }
+
+  return strs;
+}
+
+// Creates the IFRT client for the backend named by env::kEnvPjRtDevice
+// ("CPU", "TPU" or "TPU_C_API"; anything else is reported through XLA_ERROR),
+// then registers a global ordinal, a name->device entry and a lock for every
+// device, plus one extra lock for the SPMD device.
+IfrtComputationClient::IfrtComputationClient() {
+  std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
+  if (device_type == "CPU") {
+    TF_VLOG(1) << "Initializing PjRt CPU client...";
+    bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
+    int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
+    // ConsumeValue checks the status before unwrapping, so a failed client
+    // creation is reported with its error instead of an unchecked .value().
+    client_ = xla::ifrt::PjRtClient::Create(
+        ConsumeValue(xla::GetTfrtCpuClient(async, cpu_device_count)));
+  } else if (device_type == "TPU" || device_type == "TPU_C_API") {
+    TF_VLOG(1) << "Initializing TFRT TPU client...";
+    XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
+        "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")));
+    // XLA_CHECK_OK keeps the status message, matching the check above.
+    XLA_CHECK_OK(pjrt::InitializePjrtPlugin("tpu"));
+    client_ =
+        xla::ifrt::PjRtClient::Create(ConsumeValue(xla::GetCApiClient("TPU")));
+  } else {
+    XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
+                                   device_type);
+  }
+
+  XLA_CHECK(client_.get() != nullptr);
+
+  // PjRtDevice IDs are not guaranteed to be dense, so we need to track
+  // a device's global ordinal separately from its device ID. Order the
+  // devices by increasing ID to assign global ordinals.
+  std::vector ordered_devices(client_->device_count());
+  std::partial_sort_copy(client_->devices().begin(), client_->devices().end(),
+                         ordered_devices.begin(), ordered_devices.end(),
+                         [](auto& a, auto& b) { return a->id() < b->id(); });
+  for (auto* device : ordered_devices) {
+    // Read the size before operator[] inserts the new entry so the assigned
+    // ordinal does not depend on operand evaluation order.
+    const auto next_ordinal = global_ordinals_.size();
+    global_ordinals_[device->id()] = next_ordinal;
+    std::string device_str = PjRtDeviceToString(device);
+    string_to_device_.emplace(device_str, device);
+    device_locks_.emplace(device_str, std::make_unique());
+  }
+  // manually create the device_locks for SPMD device
+  device_locks_.emplace(spmd_device_str, std::make_unique());
+}
+
+// Makes this object share `data`'s buffer. `data` must be an IfrtData (the
+// reference dynamic_cast throws otherwise); self-assignment is a no-op.
+void IfrtComputationClient::IfrtData::Assign(
+    const torch::lazy::BackendData& data) {
+  const IfrtData& pjrt_data = dynamic_cast(data);
+  if (&pjrt_data != this) {
+    buffer = pjrt_data.buffer;
+  }
+}
+
+// Creates a data object for `device` and `shape` without passing a buffer.
+ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder(
+    std::string device, xla::Shape shape) {
+  return std::make_shared(device, shape);
+}
+
+// Not implemented in this prototype: always reports an error. The
+// commented-out body below refers to PjRt types and is kept for reference.
+std::vector IfrtComputationClient::GetDataShards(
+    ComputationClient::DataPtr data) {
+  XLA_ERROR() << __FUNCTION__ << " not implemented";
+  // tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShards",
+  //                                 tsl::profiler::TraceMeLevel::kInfo);
+  // std::vector shards;
+  // if (PjRtShardedData* sharded_data =
+  //         dynamic_cast(data.get())) {
+  //   for (auto shard : sharded_data->shards) {
+  //     shards.push_back(std::make_shared(
+  //         shard->device(), shard->shape(), shard->buffer));
+  //   }
+  // } else {
+  //   shards.push_back(data);
+  // }
+  // return shards;
+}
+
+// Not implemented in this prototype: always reports an error.
+ComputationClient::DataPtr IfrtComputationClient::GetDataShard(
+    ComputationClient::DataPtr data, size_t index) {
+  XLA_ERROR() << __FUNCTION__ << " not implemented";
+  // tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShard",
+  //                                 tsl::profiler::TraceMeLevel::kInfo);
+  // if (PjRtShardedData* sharded_data =
+  //
dynamic_cast(data.get())) { + // XLA_CHECK_LE(index, sharded_data->shards.size()) + // << "GetDataShard out of range with index: " << index + // << " and num of shard: " << sharded_data->shards.size(); + // std::shared_ptr shard = sharded_data->shards[index]; + // return std::make_shared(shard->device(), shard->shape(), + // shard->buffer); + // } else { + // return data; + // } +} + +ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( + const std::vector& shards, std::string device, xla::Shape shape, + xla::OpSharding sharding) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + // std::vector> pjrt_data_shards; + // pjrt_data_shards.reserve(shards.size()); + // for (auto& shard : shards) { + // XLA_CHECK(shard != nullptr); + // auto pjrt_shard = dynamic_cast(shard.get()); + // pjrt_data_shards.push_back(std::make_shared( + // pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer)); + // } + // return std::make_shared(device, shape, pjrt_data_shards, + // sharding); +} + +std::optional IfrtComputationClient::GetDataSharding( + DataPtr handle) { + return std::nullopt; + // if (auto sharded_data = dynamic_cast(handle.get())) { + // return sharded_data->GetSharding(); + // } + // return std::optional(); +} + +std::vector IfrtComputationClient::TransferToServer( + absl::Span tensors) { + metrics::TimedSection timed(TransferToServerMetric()); + tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", + tsl::profiler::TraceMeLevel::kInfo); + std::vector datas; + datas.reserve(tensors.size()); + int64_t total_size = 0; + for (auto& tensor : tensors) { + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device); + + auto literal = std::make_shared(tensor.shape); + tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes()); + std::vector byte_strides(literal->shape().dimensions_size()); + XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(), + absl::MakeSpan(byte_strides))); + total_size += 
literal->size_bytes(); + + // Avoid use-after-free on `literal` due to unsequenced move and use. + xla::Literal* literal_pointer = literal.get(); + tsl::RCReference buffer = + client_ + ->MakeArrayFromHostBuffer( + literal_pointer->untyped_data(), + xla::ifrt::ToDType(literal_pointer->shape().element_type()).value(), + xla::ifrt::Shape(literal_pointer->shape().dimensions()), + byte_strides, + // TODO: what is MemoryKind? + xla::ifrt::SingleDeviceSharding::Create(pjrt_device, xla::ifrt::MemoryKind()), + xla::PjRtClient::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [literal{std::move(literal)}]() { /* frees literal */ }) + .value(); + + ComputationClient::DataPtr data = + std::make_shared(tensor.device, tensor.shape, buffer); + datas.push_back(data); + } + OutboundDataMetric()->AddSample(total_size); + CreateDataHandlesCounter()->AddValue(datas.size()); + + return datas; +} + +ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( + absl::Span tensor_shards, std::string device, + xla::Shape shape, xla::OpSharding sharding) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + // tsl::profiler::TraceMe activity( + // "IfrtComputationClient::TransferShardsToServer", + // tsl::profiler::TraceMeLevel::kInfo); + // // TODO(jonbolin): Consider using CopyToDevice when sharding is REPLICATED. + // // We are opting out of CopyToDevice for now due to the synchronization + // // issues observed in ShardingUtil::InputHandler, but because CopyToDevice + // // directly copies buffers between devices using ICI, it can be much faster + // // than transferring from the host to each device. 
+ // auto data_shards = TransferToServer(tensor_shards); + // std::vector> pjrt_data_shards; + // for (auto& shard : data_shards) { + // auto pjrt_shard = dynamic_cast(shard.get()); + // pjrt_data_shards.push_back(std::make_shared( + // pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer)); + // } + // return std::make_shared(device, shape, pjrt_data_shards, + // sharding); +} + +ComputationClient::DataPtr IfrtComputationClient::CopyToDevice( + ComputationClient::DataPtr data, std::string dst) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + // tsl::profiler::TraceMe activity("IfrtComputationClient::CopyToDevice", + // tsl::profiler::TraceMeLevel::kInfo); + // const PjRtData* pjrt_data = dynamic_cast(data.get()); + // XLA_CHECK(pjrt_data->HasValue()) << "Can't copy invalid device data."; + + // xla::PjRtDevice* dst_device = StringToPjRtDevice(dst); + // XLA_CHECK(dst_device->IsAddressable()) << dst << "is not addressable."; + + // // Returns error if the buffer is already on `dst_device`. 
+ // xla::StatusOr> status_or = + // pjrt_data->buffer->CopyToDevice(dst_device); + // XLA_CHECK(status_or.ok()) + // << pjrt_data->device() << " buffer already exists on " << dst; + + // return std::make_shared(dst, pjrt_data->shape(), + // std::move(status_or.value())); +} + +ComputationClient::DataPtr IfrtComputationClient::ReplicateShardedData( + const ComputationClient::DataPtr& handle) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + // if (PjRtShardedData* sharded_data = + // dynamic_cast(handle.get())) { + // XLA_COUNTER("ReplicateShardedData", 1); + // TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() + // << ", shape=" << handle->shape() << ")"; + // if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { + // // Data is replicated, return the first shard + // return sharded_data->shards[0]; + // } + // xla::XlaBuilder builder("ReplicateShardedData"); + // xla::Shape shape = sharded_data->shape(); + // builder.SetSharding(sharded_data->GetSharding()); + + // // perform a simple identity calculation to reassemble the input as + // // replicated output. 
+ // xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0"); + // builder.SetSharding(xla::HloSharding::Replicate().ToProto()); + // xla::XlaOp scalar_zero_op = xla::ConvertElementType( + // xla::ConstantR0(&builder, 0), shape.element_type()); + // xla::XlaOp y = xla::Add(x, scalar_zero_op); + // auto instruction = XlaBuilderFriend::GetInstruction(y); + // *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto(); + + // xla::XlaComputation computation = + // ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false)); + // xla::ProgramShape program_shape = + // ConsumeValue(computation.GetProgramShape()); + + // std::string device = GetDefaultDevice(); + // std::vector + // instances; + // instances.push_back({std::move(computation), device, + // GetCompilationDevices(device, {}), &shape, + // /*should_wrap_parameter=*/false, + // /*is_sharded=*/true, + // /*allow_spmd_sharding_propagation_to_output=*/false}); + // std::vector< + // std::shared_ptr> + // computations = Compile(std::move(instances)); + + // auto shards = sharded_data->shards; + // XLA_CHECK_EQ(shards.size(), GetLocalDevices().size()); + // auto device_index = build_index_map(GetLocalDevices()); + + // std::vector> arguments_by_device( + // GetLocalDevices().size(), std::vector(1)); + // for (auto shard : shards) { + // std::vector device_spec = + // absl::StrSplit(shard->device(), ':'); + // XLA_CHECK_EQ(device_spec.size(), 2) + // << "Invalid device specification: " << shard->device(); + // int device_i = device_index[std::stoi(device_spec[1])]; + // TF_VLOG(3) << shard->device() << " is mapped to local device index " + // << device_i; + // arguments_by_device[device_i][0] = shard; + // } + // torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions + // execute_options; + // auto sharded_results = + // ExecuteReplicated(*computations.front(), arguments_by_device, + // GetLocalDevices(), execute_options); + // XLA_CHECK(sharded_results.size() > 0) + // << "empty 
ExecuteReplicated results returned.";
+  //   XLA_CHECK(sharded_results[0].size() == 1)
+  //       << "Wrong number of outputs, expected: 1, actual: "
+  //       << sharded_results[0].size();
+  //   return sharded_results[0][0];
+  // }
+  // return handle;
+}
+
+// Copies every handle's device buffer into a freshly allocated host literal
+// and returns the literals in the same order as `handles`. Blocks until each
+// copy has completed (Await) and records the total byte count in the inbound
+// data metric.
+std::vector IfrtComputationClient::TransferFromServer(
+    absl::Span handles) {
+  metrics::TimedSection timed(TransferFromServerMetric());
+  tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromServer",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  std::vector literals;
+  literals.reserve(handles.size());
+  int64_t total_size = 0;
+  // Iterate by reference: the handles are shared pointers and a by-value loop
+  // variable would bump the refcount on every iteration.
+  for (const auto& handle : handles) {
+    // Use XLA replication to reassemble the sharded data. If input handle
+    // is not sharded, then it is a no-op.
+    // auto new_handle = ReplicateShardedData(handle);
+    auto pjrt_data = std::dynamic_pointer_cast(handle);
+    // The pointer cast yields null on a type mismatch; fail loudly instead of
+    // dereferencing it.
+    XLA_CHECK(pjrt_data != nullptr)
+        << "TransferFromServer received a handle of an unexpected data type";
+
+    // TODO: handle dynamic shapes
+    auto& literal = literals.emplace_back(
+        xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape()));
+    std::vector byte_strides(literal.shape().dimensions_size());
+    XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(),
+                                             absl::MakeSpan(byte_strides)));
+    XLA_CHECK_OK(
+        pjrt_data->buffer
+            ->CopyToHostBuffer(literal.untyped_data(), byte_strides,
+                               xla::ifrt::ArrayCopySemantics::kAlwaysCopy)
+            .Await());
+
+    total_size += literal.size_bytes();
+  }
+  InboundDataMetric()->AddSample(total_size);
+
+  return literals;
+}
+
+// Compiles each instance: builds the compile options (SPMD partitioning over
+// all devices when `is_sharded`, otherwise one replica per device), converts
+// the HLO to StableHLO, compiles it with the client's default compiler and
+// wraps the loaded executable. Returns one computation per instance, in order.
+std::vector IfrtComputationClient::Compile(
+    std::vector instances) {
+  metrics::TimedSection timed(CompileMetric());
+  tsl::profiler::TraceMe activity("IfrtComputationClient::Compile",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  std::vector computations;
+  computations.reserve(instances.size());
+
+  for (auto& instance : instances) {
+    xla::CompileOptions compile_options;
+    if (instance.is_sharded) {
+      // TODO(yeounoh) multi-host, multi-slice configurations
+      compile_options.executable_build_options.set_use_spmd_partitioning(true);
+      // We can override the compiler's default behavior to replicate the
+      // outputs. Setting this to true would wrap the sharded outputs in
+      // PjRtShardedData.
+      compile_options.executable_build_options
+          .set_allow_spmd_sharding_propagation_to_output(
+              {instance.allow_spmd_sharding_propagation_to_output});
+      compile_options.executable_build_options.set_num_partitions(
+          client_->device_count());
+      compile_options.executable_build_options.set_num_replicas(1);
+      compile_options.parameter_is_tupled_arguments =
+          instance.parameter_is_tupled_arguments;
+
+      // TODO(244391366) verify this is correct for the collectives ops
+      xla::DeviceAssignment device_assignment(1, client_->device_count());
+      // DeviceAssignment values must be the PjRtDevice ID, so we need to
+      // unwind the global ordinal mapping.
+      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+        device_assignment(0, global_ordinal) = device_id;
+      }
+      compile_options.executable_build_options.set_device_assignment(
+          device_assignment);
+    } else {
+      // TODO(wcromar): set compile_options.argument_layouts, enable strict
+      // shapes
+      compile_options.executable_build_options.set_num_partitions(1);
+      compile_options.executable_build_options.set_num_replicas(
+          client_->device_count());
+      compile_options.parameter_is_tupled_arguments =
+          instance.parameter_is_tupled_arguments;
+
+      xla::DeviceAssignment device_assignment(client_->device_count(), 1);
+      // DeviceAssignment values must be the PjRtDevice ID, so we need to
+      // unwind the global ordinal mapping.
+      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+        device_assignment(global_ordinal, 0) = device_id;
+      }
+      compile_options.executable_build_options.set_device_assignment(
+          device_assignment);
+    }
+
+    // Convert HLO to StableHLO for Ifrt client compilation.
+    mlir::MLIRContext context;
+    mlir::ModuleOp mlir_module =
+        mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
+    torch_xla::runtime::ConvertHloToStableHlo(
+        instance.computation.mutable_proto(), &mlir_module);
+    std::unique_ptr executable = ConsumeValue(client_->GetDefaultCompiler()->Compile(
+        std::make_unique(std::move(mlir_module)),
+        std::make_unique(compile_options)));
+    StableHloCompileCounter()->AddValue(1);
+
+    const auto hlo_modules = ConsumeValue(executable->GetHloModules());
+    // Only the first module is wrapped below; make sure it exists before
+    // indexing.
+    XLA_CHECK(!hlo_modules.empty())
+        << "Compiled executable returned no HLO modules";
+
+    // The XlaComputation is a temporary, so it is passed directly; wrapping a
+    // prvalue in std::move adds nothing.
+    std::shared_ptr pjrt_computation =
+        std::make_shared(
+            xla::XlaComputation(hlo_modules[0]->ToProto()),
+            instance.devices, std::move(executable));
+
+    computations.push_back(pjrt_computation);
+
+    CreateCompileHandlesCounter()->AddValue(1);
+  }
+
+  return computations;
+}
+
+std::vector
+IfrtComputationClient::ExecuteComputation(
+    const ComputationClient::Computation& computation,
+    absl::Span arguments,
+    const std::string& device, const ExecuteComputationOptions& options) {
+  // Shared ownership of the timed section ensures that it will only get logged
+  // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
+  // complete; a copy is held from the lambda that releases it when done.
+  auto timed = std::make_shared(ExecuteMetric());
+  tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteComputation",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  TF_VLOG(1) << "Executing Ifrt computation on " << device;
+  const IfrtComputation& pjrt_computation =
+      dynamic_cast(computation);
+
+  xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
+  XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
+
+  std::vector> buffers;
+  buffers.reserve(arguments.size());
+  for (auto& argument : arguments) {
+    const IfrtData* pjrt_data = dynamic_cast(argument.get());
+    // The pointer cast yields null on a type mismatch; fail loudly instead of
+    // dereferencing it.
+    XLA_CHECK(pjrt_data != nullptr)
+        << "ExecuteComputation argument is not IfrtData";
+
+    // XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
+    //     << pjrt_device->DebugString() << " vs "
+    //     << pjrt_data->buffer->device()->DebugString();
+    buffers.push_back(pjrt_data->buffer);
+  }
+
+  xla::ExecuteOptions execute_options;
+  execute_options.untuple_result = options.explode_tuple;
+  execute_options.strict_shape_checking = false;
+
+  // Required as of cl/518733871
+  execute_options.use_major_to_minor_data_layout_for_callbacks = true;
+
+  xla::ifrt::DeviceList device_list({pjrt_device});
+  // ConsumeValue checks the status before unwrapping, so a failed launch is
+  // reported with its error instead of an unchecked .value().
+  xla::ifrt::LoadedExecutable::ExecuteResult result = ConsumeValue(
+      pjrt_computation.executable->Execute(absl::MakeSpan(buffers),
+                                           execute_options, device_list));
+
+  xla::ifrt::Future returned_future = result.status;
+
+  // Take ownership of the outputs instead of copying the vector; `result` is
+  // not read again after this point.
+  auto results = std::move(result.outputs);
+  std::vector datas;
+  datas.reserve(results.size());
+  for (auto& output : results) {
+    tsl::RCReference buffer = std::move(output);
+
+    std::shared_ptr data =
+        std::make_shared(device, std::move(buffer));
+
+    datas.push_back(data);
+  }
+  CreateDataHandlesCounter()->AddValue(datas.size());
+
+  auto mwait = std::make_shared(1);
+  // This closure runs on another thread after ExecuteComputation may have
+  // returned, so everything it needs is captured by value or by move; there is
+  // deliberately no default by-reference capture.
+  auto lockfn = [this, device, returned_future = std::move(returned_future),
+                 timed]() mutable {
+    TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
+               << device;
+    // Grab the shared lock and block the `WaitDeviceOps` until buffer is
+    // ready.
+    // TODO(JackCaoG): This lock should be acquired outside of the lockfn and
+    // passed in. It is possible that lockfn started after ExecuteComputation
+    // released the xla_graph_executor lock, which will create a short window
+    // where device is unlocked while execution is still running.
+    auto lock = lock_device_shared(device);
+    TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
+               << " Done";
+    // Signal that `ExecuteSharded` has completed for the ExecuteTime
+    // metric. Copies the `timed` shared pointer into the lambda.
+    XLA_CHECK(returned_future.IsValid())
+        << "returned_future in ExecuteComputation is empty";
+    returned_future.OnReady(
+        [timed, lock = std::move(lock)](xla::Status unused) mutable {
+          timed.reset();
+          TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
+        });
+  };
+
+  env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));
+
+  TF_VLOG(1) << "Returning " << datas.size() << " results";
+  return datas;
+}
+
+std::vector>
+IfrtComputationClient::ExecuteReplicated(
+    const ComputationClient::Computation& computation,
+    const std::vector>& arguments,
+    absl::Span devices,
+    const ExecuteReplicatedOptions& options) {
+  XLA_ERROR() << __FUNCTION__ << " not implemented";
+  // Shared ownership of the timed section ensures that it will only get logged
+  // once both `ExecuteReplicated` and the async work in `Execute` are
+  // complete; a copy is held from the lambda that releases it when done.
+ // auto timed = + // std::make_shared(ExecuteReplicatedMetric()); + // tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated", + // tsl::profiler::TraceMeLevel::kInfo); + // const PjRtComputation& pjrt_computation = + // dynamic_cast(computation); + // XLA_CHECK(devices.size() == arguments.size()) + // << "ExecuteReplicated over " << devices.size() << " devices, but " + // << arguments.size() << " arguments devices."; + // auto mwait_argument = std::make_shared(devices.size()); + // std::vector> argument_handles(devices.size()); + // { + // tsl::profiler::TraceMe activity( + // "IfrtComputationClient::ExecuteReplicated_argument_handle", + // tsl::profiler::TraceMeLevel::kInfo); + // for (int32_t i = 0; i < devices.size(); ++i) { + // auto buffer_converter = [&, i]() { + // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); + // XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + + // std::vector buffers; + // for (auto& argument : arguments[i]) { + // const PjRtData* pjrt_data = dynamic_cast(argument.get()); + + // XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) + // << pjrt_device->DebugString() << " vs " + // << pjrt_data->buffer->device()->DebugString(); + // buffers.push_back(pjrt_data->buffer.get()); + // } + // argument_handles[i] = std::move(buffers); + // }; + // env::ScheduleIoClosure(util::MultiWait::Completer( + // mwait_argument, std::move(buffer_converter))); + // } + // mwait_argument->Wait(); + // } + + // xla::ExecuteOptions execute_options; + // execute_options.untuple_result = options.explode_tuple; + // execute_options.strict_shape_checking = true; + // // TODO(yeounoh) currently only support single-slice execution + // execute_options.multi_slice_config = nullptr; + + // // Required as of cl/518733871 + // execute_options.use_major_to_minor_data_layout_for_callbacks = true; + + // std::optional>> returned_futures( + // devices.size()); + // std::vector>> results; + // { + // 
tsl::profiler::TraceMe activity( + // "IfrtComputationClient::ExecuteReplicated_execute", + // tsl::profiler::TraceMeLevel::kInfo); + // results = pjrt_computation.executable + // ->Execute(std::move(argument_handles), execute_options, + // returned_futures) + // .value(); + // } + + // std::vector> data_handles; + // data_handles.reserve(results.size()); + // std::vector dims(results.size()); + + // { + // tsl::profiler::TraceMe activity( + // "IfrtComputationClient::ExecuteReplicated_result_handle", + // tsl::profiler::TraceMeLevel::kInfo); + // for (int32_t i = 0; i < results.size(); ++i) { + // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); + // XLA_CHECK(pjrt_device->IsAddressable()) + // << pjrt_device->DebugString() << " is not addressable."; + + // std::vector datas; + // datas.reserve(results[i].size()); + // dims[i] = results[i].size(); + // for (int32_t j = 0; j < results[i].size(); ++j) { + // std::unique_ptr buffer = std::move(results[i][j]); + // XLA_CHECK(pjrt_device == buffer->device()) + // << "Exepcted device: " << pjrt_device->DebugString() + // << " vs. actual device: " << buffer->device()->DebugString(); + + // std::shared_ptr data = + // std::make_shared(devices[i], std::move(buffer)); + // datas.push_back(data); + // } + // data_handles.push_back(datas); + // } + // } + + // auto mwait = std::make_shared(1); + // auto lockfn = [&, this, returned_futures = std::move(*returned_futures), + // timed]() mutable { + // // Grab the shared lock and block the `WaitDeviceOps` until buffer is + // // ready. Since this is the SPMD code path. There is no points to grab + // // devices lock for every individual device. 
+ // TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " + // << spmd_device_str; + // auto lock = lock_device_shared(spmd_device_str); + // TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " + // << spmd_device_str << " Done"; + // // Signal that `ExecuteReplicated` has completed for one of the devices + // // the ExecuteReplicatedTime metric. Here, we assume that all devices + // // will finish execution roughly at the same time, hence only use one of + // // the returned_futures. Copies the `timed` shared pointer into the + // // lambda. + // XLA_CHECK(returned_futures[0].IsValid()) + // << "returned_future in ExecuteReplicated is empty"; + // returned_futures[0].OnReady( + // [timed, lock = std::move(lock)](xla::Status unused) mutable { + // timed.reset(); + // TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished"; + // }); + // }; + // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); + + // TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " + // << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; + // return data_handles; +} + +size_t IfrtComputationClient::GetNumDevices() const { + return client_->addressable_device_count(); +} + +std::string IfrtComputationClient::GetDefaultDevice() const { + return PjRtDeviceToString(client_->addressable_devices()[0]); +} + +std::vector IfrtComputationClient::GetLocalDevices() const { + return PjRtDevicesToString(client_->addressable_devices()); +} + +std::vector IfrtComputationClient::GetAllDevices() const { + return PjRtDevicesToString(client_->devices()); +} + +int IfrtComputationClient::GetNumProcesses() const { + int max_process_index = client_->process_index(); + for (auto* device : client_->devices()) { + max_process_index = std::max(max_process_index, device->process_index()); + } + + return max_process_index + 1; +}; + +const absl::flat_hash_map< + std::string, 
torch_xla::runtime::ComputationClient::DeviceAttribute>&
+IfrtComputationClient::GetDeviceAttributes(const std::string& device) {
+  return IfrtComputationClient::StringToPjRtDevice(device)->Attributes();
+}
+
+// Stores the replication device list, taking ownership of the shared pointer.
+void IfrtComputationClient::SetReplicationDevices(
+    std::shared_ptr> devices) {
+  replication_devices_ = std::move(devices);
+}
+
+// Returns the list stored by SetReplicationDevices.
+std::shared_ptr>
+IfrtComputationClient::GetReplicationDevices() {
+  return replication_devices_;
+}
+
+// Resolves a device string to the device registered under that name in the
+// constructor. Fails an XLA_CHECK for an unknown name.
+xla::PjRtDevice* IfrtComputationClient::StringToPjRtDevice(
+    const std::string& device) {
+  // A single find() serves both the existence check and the lookup.
+  const auto it = string_to_device_.find(device);
+  XLA_CHECK(it != string_to_device_.end()) << "Unknown device " << device;
+  return it->second;
+}
+
+// Acquires the lock registered for `device` in shared mode and returns it.
+// Fails an XLA_CHECK for an unknown name: operator[] would default-insert a
+// null pointer, which would then be dereferenced.
+std::shared_lock IfrtComputationClient::lock_device_shared(
+    const std::string& device) {
+  const auto it = device_locks_.find(device);
+  XLA_CHECK(it != device_locks_.end()) << "Unknown device " << device;
+  std::shared_lock lock(*it->second);
+  return lock;
+}
+
+// Acquires the lock registered for `device` in exclusive mode and returns it.
+// Fails an XLA_CHECK for an unknown name, for the same reason as above.
+std::unique_lock IfrtComputationClient::lock_device(
+    const std::string& device) {
+  const auto it = device_locks_.find(device);
+  XLA_CHECK(it != device_locks_.end()) << "Unknown device " << device;
+  std::unique_lock lock(*it->second);
+  return lock;
+}
+
+// Waits on each device in `devices` (or on every local device when `devices`
+// is empty) by taking and immediately dropping that device's exclusive lock.
+void IfrtComputationClient::WaitDeviceOps(
+    const std::vector& devices) {
+  std::unordered_set wait_devices;
+  if (!devices.empty()) {
+    for (auto& device_str : devices) {
+      wait_devices.insert(device_str);
+    }
+  } else {
+    for (auto& device_str : GetLocalDevices()) {
+      wait_devices.insert(device_str);
+    }
+  }
+  for (const std::string& device_str : wait_devices) {
+    TF_VLOG(3) << "Waiting for device execution for " << device_str
+               << " to finish";
+    // The returned lock is a temporary: it is taken and released within this
+    // statement, so the call only returns once no shared holder remains.
+    lock_device(device_str);
+    TF_VLOG(3) << "Waiting for device execution for " << device_str
+               << " to finish.. 
Done"; + } +} + +std::map IfrtComputationClient::GetMetrics() const { + // TODO(jonbolin): Add any PJRt-client-specific metrics here + return {}; +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h new file mode 100644 index 00000000000..f35033a6e2d --- /dev/null +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -0,0 +1,241 @@ +#ifndef XLA_CLIENT_IFRT_COMPUTATION_CLIENT_H_ +#define XLA_CLIENT_IFRT_COMPUTATION_CLIENT_H_ + +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/util.h" +#include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/shape.h" + +namespace torch_xla { +namespace runtime { + +class IfrtComputationClient : public ComputationClient { + public: + IfrtComputationClient(); + + DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; + + std::vector GetDataShards(DataPtr data) override; + + DataPtr GetDataShard(DataPtr data, size_t index) override; + + DataPtr WrapDataShards(const std::vector& shards, std::string device, + xla::Shape shape, xla::OpSharding sharding) override; + + std::optional GetDataSharding(DataPtr handle) override; + + std::vector TransferToServer( + absl::Span tensors) override; + + // Use XLA replication to re-assemble the sharded data. 
+ DataPtr ReplicateShardedData(const DataPtr& handle); + + std::vector TransferFromServer( + absl::Span handles) override; + + DataPtr TransferShardsToServer(absl::Span tensor_shards, + std::string device, xla::Shape shape, + xla::OpSharding sharding) override; + + DataPtr CopyToDevice(DataPtr data, std::string dst) override; + + std::vector Compile( + std::vector instances) override; + + std::vector ExecuteComputation( + const Computation& computation, absl::Span arguments, + const std::string& device, + const ExecuteComputationOptions& options) override; + + std::vector> ExecuteReplicated( + const Computation& computation, + const std::vector>& arguments, + absl::Span devices, + const ExecuteReplicatedOptions& options) override; + + size_t GetNumDevices() const override; + + std::string GetDefaultDevice() const override; + + std::vector GetLocalDevices() const override; + + std::vector GetAllDevices() const override; + + int GetProcessIndex() const override { return client_->process_index(); }; + + int GetNumProcesses() const override; + + const absl::flat_hash_map< + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + GetDeviceAttributes(const std::string& device) override; + + void SetReplicationDevices( + std::shared_ptr> devices) override; + + std::shared_ptr> GetReplicationDevices() override; + + void PrepareToExit() override { return; }; + + void WaitDeviceOps(const std::vector& devices) override; + + std::map GetMetrics() const override; + + // NOT IMPLEMENTED + + MemoryInfo GetMemoryInfo(const std::string& device) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; + + private: + std::shared_ptr client_; + // global_ordinals_ tracks a map from PjRtDeviceId to the device's + // dense global ordinal. 
+ std::unordered_map global_ordinals_; + std::unordered_map string_to_device_; + std::shared_ptr> replication_devices_; + std::unordered_map> + device_locks_; + + xla::PjRtDevice* StringToPjRtDevice(const std::string& device); + std::shared_lock lock_device_shared( + const std::string& device); + std::unique_lock lock_device(const std::string& device); + + std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; + std::vector PjRtDevicesToString( + absl::Span devices) const; + + struct IfrtData : public Data { + IfrtData(std::string device, xla::Shape device_shape) + : Data(std::move(device), std::move(device_shape)) {} + + IfrtData(std::string device, xla::Shape device_shape, + tsl::RCReference buffer) + : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} + + IfrtData(std::string device, tsl::RCReference buffer) + : Data(std::move(device), + xla::ShapeUtil::MakeShape( + xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), + buffer->shape().dims())), + buffer(buffer) {} + + Handle GetHandle() override { + XLA_CHECK(HasValue()) + << "buffer with shape " << shape().ToString() << " on device " + << device() << (buffer == nullptr ? 
" is null" : " is deleted"); + return reinterpret_cast(buffer.get()); + }; + void Assign(const torch::lazy::BackendData& data) override; + bool HasValue() const override { + return buffer != nullptr; // TODO: && !buffer->IsDeleted(); + }; + + bool HasSharding() const override { return false; } + + xla::OpSharding GetSharding() const override { + XLA_ERROR() << "GetSharding should not be called on IfrtData, check " + "HasSharding first"; + return xla::OpSharding(); + } + + std::string ToString() const override { + std::stringstream ss; + ss << "XLAData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " Data Handle: "; + if (HasValue()) { + ss << reinterpret_cast(buffer.get()) << "\n"; + } else { + ss << "None\n"; + } + return ss.str(); + } + + tsl::RCReference buffer; + }; + + // struct PjRtShardedData : public Data { + // PjRtShardedData(std::string device, xla::Shape shape) = delete; + + // PjRtShardedData(std::string device, xla::Shape shape, + // std::vector> shards, + // xla::OpSharding sharding) + // : Data(std::move(device), std::move(shape)), + // shards(shards), + // sharding(sharding) {} + + // Handle GetHandle() override { + // // Always returns `Handle` of the first shard. 
+ // return shards[0]->GetHandle(); + // } + + // void Assign(const torch::lazy::BackendData& data) override { + // const PjRtShardedData& pjrt_sharded_data = + // dynamic_cast(data); + // if (&pjrt_sharded_data != this) { + // shards = std::move(pjrt_sharded_data.shards); + // } + // } + + // bool HasValue() const override { + // if (shards.empty()) { + // return false; + // } + + // for (auto& shard : shards) { + // if (!shard->HasValue()) { + // return false; + // } + // } + // return true; + // } + + // std::string ToString() const override { + // std::stringstream ss; + // ss << "XLAShardedData: \n"; + // ss << " Data Device: " << device() << "\n"; + // ss << " Data Shape: " << shape().ToString() << "\n"; + // ss << " OpSharding: " + // << xla::HloSharding::FromProto(sharding)->ToString() << "\n"; + // ss << " NumShards: " << shards.size() << "\n"; + // return ss.str(); + // } + + // bool HasSharding() const override { return true; } + + // xla::OpSharding GetSharding() const override { return sharding; } + + // std::vector> shards; + // xla::OpSharding sharding; + // }; + + struct IfrtComputation : public Computation { + IfrtComputation(xla::XlaComputation computation, + std::vector devices, + std::unique_ptr executable) + : Computation(std::move(computation), std::move(devices)), + executable(std::move(executable)) {} + + std::unique_ptr executable; + }; +}; + +} // namespace runtime +} // namespace torch_xla +#endif // XLA_CLIENT_IFRT_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index e2f69c44e47..5d443d29d29 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -4,6 +4,7 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "tsl/platform/stacktrace_handler.h" namespace torch_xla { @@ 
-19,8 +20,14 @@ ComputationClient* GetComputationClient() { std::unique_ptr client; + static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); + ComputationClient* client; if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { - client = std::make_unique(); + if (use_ifrt) { + client = std::make_unique(); + } else { + client = std::make_unique(); + } } else { XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; } From c7e48b467178e11a8ee9f0bea9eb1269d927eab0 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 5 Oct 2023 17:40:57 +0000 Subject: [PATCH 02/33] remove comment --- torch_xla/csrc/runtime/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index e38baa03a93..62d043308b0 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -81,8 +81,6 @@ cc_library( ], deps = [ ":computation_client", - # TODO: why do I need this? - # ":pjrt_computation_client", ":debug_macros", ":env_vars", ":multi_wait", From 131246525b0e38a7a410e446b2238a7095cbec5e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 5 Oct 2023 17:41:21 +0000 Subject: [PATCH 03/33] formatting --- .../csrc/runtime/ifrt_computation_client.cc | 49 ++++++++++++------- .../csrc/runtime/ifrt_computation_client.h | 8 +-- torch_xla/csrc/runtime/runtime.cc | 2 +- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 60ff8056c68..6d556760dc3 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -110,14 +110,16 @@ IfrtComputationClient::IfrtComputationClient() { TF_VLOG(1) << "Initializing PjRt CPU client..."; bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = 
xla::ifrt::PjRtClient::Create(std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value())); + client_ = xla::ifrt::PjRtClient::Create( + std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value())); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; XLA_CHECK_OK(pjrt::LoadPjrtPlugin( "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK(tpu_status.ok()); - client_ = xla::ifrt::PjRtClient::Create(std::move(xla::GetCApiClient("TPU").value())); + client_ = xla::ifrt::PjRtClient::Create( + std::move(xla::GetCApiClient("TPU").value())); } else { XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, device_type); @@ -240,11 +242,13 @@ std::vector IfrtComputationClient::TransferToServer( client_ ->MakeArrayFromHostBuffer( literal_pointer->untyped_data(), - xla::ifrt::ToDType(literal_pointer->shape().element_type()).value(), + xla::ifrt::ToDType(literal_pointer->shape().element_type()) + .value(), xla::ifrt::Shape(literal_pointer->shape().dimensions()), byte_strides, // TODO: what is MemoryKind? 
- xla::ifrt::SingleDeviceSharding::Create(pjrt_device, xla::ifrt::MemoryKind()), + xla::ifrt::SingleDeviceSharding::Create( + pjrt_device, xla::ifrt::MemoryKind()), xla::PjRtClient::HostBufferSemantics:: kImmutableUntilTransferCompletes, [literal{std::move(literal)}]() { /* frees literal */ }) @@ -328,7 +332,8 @@ ComputationClient::DataPtr IfrtComputationClient::ReplicateShardedData( // xla::ConstantR0(&builder, 0), shape.element_type()); // xla::XlaOp y = xla::Add(x, scalar_zero_op); // auto instruction = XlaBuilderFriend::GetInstruction(y); - // *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto(); + // *instruction->mutable_sharding() = + // xla::HloSharding::Replicate().ToProto(); // xla::XlaComputation computation = // ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false)); @@ -352,7 +357,8 @@ ComputationClient::DataPtr IfrtComputationClient::ReplicateShardedData( // auto device_index = build_index_map(GetLocalDevices()); // std::vector> arguments_by_device( - // GetLocalDevices().size(), std::vector(1)); + // GetLocalDevices().size(), + // std::vector(1)); // for (auto shard : shards) { // std::vector device_spec = // absl::StrSplit(shard->device(), ':'); @@ -394,12 +400,15 @@ std::vector IfrtComputationClient::TransferFromServer( // TODO: handle dynamic shapes auto& literal = literals.emplace_back( - xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape())); + xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape())); std::vector byte_strides(literal.shape().dimensions_size()); XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), absl::MakeSpan(byte_strides))); - XLA_CHECK_OK(pjrt_data->buffer->CopyToHostBuffer( - literal.untyped_data(), byte_strides, xla::ifrt::ArrayCopySemantics::kAlwaysCopy).Await()); + XLA_CHECK_OK( + pjrt_data->buffer + ->CopyToHostBuffer(literal.untyped_data(), byte_strides, + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .Await()); total_size += literal.size_bytes(); } @@ -466,9 +475,10 @@ 
std::vector IfrtComputationClient::Compile( mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); torch_xla::runtime::ConvertHloToStableHlo( instance.computation.mutable_proto(), &mlir_module); - std::unique_ptr executable = ConsumeValue(client_->GetDefaultCompiler()->Compile( - std::make_unique(std::move(mlir_module)), - std::make_unique(compile_options))); + std::unique_ptr executable = + ConsumeValue(client_->GetDefaultCompiler()->Compile( + std::make_unique(std::move(mlir_module)), + std::make_unique(compile_options))); StableHloCompileCounter()->AddValue(1); const auto& hlo_modules = ConsumeValue(executable->GetHloModules()); @@ -595,7 +605,8 @@ IfrtComputationClient::ExecuteReplicated( // << "ExecuteReplicated over " << devices.size() << " devices, but " // << arguments.size() << " arguments devices."; // auto mwait_argument = std::make_shared(devices.size()); - // std::vector> argument_handles(devices.size()); + // std::vector> + // argument_handles(devices.size()); // { // tsl::profiler::TraceMe activity( // "IfrtComputationClient::ExecuteReplicated_argument_handle", @@ -603,11 +614,13 @@ IfrtComputationClient::ExecuteReplicated( // for (int32_t i = 0; i < devices.size(); ++i) { // auto buffer_converter = [&, i]() { // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - // XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + // XLA_CHECK(pjrt_device->IsAddressable()) << + // pjrt_device->DebugString(); // std::vector buffers; // for (auto& argument : arguments[i]) { - // const PjRtData* pjrt_data = dynamic_cast(argument.get()); + // const PjRtData* pjrt_data = + // dynamic_cast(argument.get()); // XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) // << pjrt_device->DebugString() << " vs " @@ -695,10 +708,12 @@ IfrtComputationClient::ExecuteReplicated( // returned_futures[0].OnReady( // [timed, lock = std::move(lock)](xla::Status unused) mutable { // timed.reset(); - // TF_VLOG(3) << "ExecuteReplicated 
returned_future->OnReady finished"; + // TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady + // finished"; // }); // }; - // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); + // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, + // std::move(lockfn))); // TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " // << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index f35033a6e2d..e0b8c63c44c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -130,8 +130,8 @@ class IfrtComputationClient : public ComputationClient { IfrtData(std::string device, tsl::RCReference buffer) : Data(std::move(device), xla::ShapeUtil::MakeShape( - xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), - buffer->shape().dims())), + xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), + buffer->shape().dims())), buffer(buffer) {} Handle GetHandle() override { @@ -142,14 +142,14 @@ class IfrtComputationClient : public ComputationClient { }; void Assign(const torch::lazy::BackendData& data) override; bool HasValue() const override { - return buffer != nullptr; // TODO: && !buffer->IsDeleted(); + return buffer != nullptr; // TODO: && !buffer->IsDeleted(); }; bool HasSharding() const override { return false; } xla::OpSharding GetSharding() const override { XLA_ERROR() << "GetSharding should not be called on IfrtData, check " - "HasSharding first"; + "HasSharding first"; return xla::OpSharding(); } diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 5d443d29d29..4a4de61254c 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -3,8 +3,8 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" 
-#include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" +#include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "tsl/platform/stacktrace_handler.h" namespace torch_xla { From a7ae0623cefbb5a6a5f56bb9494aaa7503a10ff7 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 11 Oct 2023 22:08:17 +0000 Subject: [PATCH 04/33] basic sharding --- .../csrc/runtime/ifrt_computation_client.cc | 47 ++++++++++++++++--- .../csrc/runtime/ifrt_computation_client.h | 6 +-- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 6d556760dc3..0e06c34e44f 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -30,6 +30,7 @@ #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/shape.h" using xla::internal::XlaBuilderFriend; @@ -152,6 +153,15 @@ void IfrtComputationClient::IfrtData::Assign( } } +xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { + // XLA_ERROR() << "GetSharding should not be called on IfrtData, check " + // "HasSharding first"; + const xla::ifrt::Sharding& sharding = buffer->sharding(); + auto hlo_sharding = dynamic_cast(sharding); + // TODO: why are we using the proto? 
+ return hlo_sharding.xla_hlo_sharding().ToProto(); +} + ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder( std::string device, xla::Shape shape) { return std::make_shared(device, shape); @@ -267,7 +277,8 @@ std::vector IfrtComputationClient::TransferToServer( ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( absl::Span tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + // TODO: completely ignoring OpSharding. Is that important? + // XLA_ERROR() << __FUNCTION__ << " not implemented"; // tsl::profiler::TraceMe activity( // "IfrtComputationClient::TransferShardsToServer", // tsl::profiler::TraceMeLevel::kInfo); @@ -276,15 +287,37 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( // // issues observed in ShardingUtil::InputHandler, but because CopyToDevice // // directly copies buffers between devices using ICI, it can be much faster // // than transferring from the host to each device. 
- // auto data_shards = TransferToServer(tensor_shards); + auto data_shards = TransferToServer(tensor_shards); // std::vector> pjrt_data_shards; - // for (auto& shard : data_shards) { - // auto pjrt_shard = dynamic_cast(shard.get()); - // pjrt_data_shards.push_back(std::make_shared( - // pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer)); - // } + std::vector> arrays; + std::vector shard_shapes; + for (auto& shard : data_shards) { + auto ifrt_shard = std::dynamic_pointer_cast(shard); + arrays.push_back(ifrt_shard->buffer); + shard_shapes.push_back(ifrt_shard->buffer->shape()); + // pjrt_data_shards.push_back(std::make_shared( + // pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer)); + } // return std::make_shared(device, shape, pjrt_data_shards, // sharding); + xla::ifrt::Shape ifrt_shape(shape.dimensions()); + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); + std::unique_ptr ifrt_sharding = xla::ifrt::ConcreteSharding::Create( + devices_list, + xla::ifrt::MemoryKind(), + ifrt_shape, + shard_shapes + ); + // TODO: why doesn't HloSharding work? 
+ // RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Only SingleDeviceSharding, OpaqueSharding, ConcreteSharding, ConcreteEvenSharding, and ShardingParamSharding are supported: sharding=HloSharding(memory_kind: (default), hlo_sharding: {devices=[1,4]0,1,2,3}) + // std::unique_ptr ifrt_sharding = xla::ifrt::HloSharding::Create( + // devices_list, + // xla::ifrt::MemoryKind(), + // xla::HloSharding::FromProto(sharding).value() + // ); + tsl::RCReference sharded_array = client_->AssembleArrayFromSingleDeviceArrays( + ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + return std::make_shared(device, shape, sharded_array); } ComputationClient::DataPtr IfrtComputationClient::CopyToDevice( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index e0b8c63c44c..f3767ca4615 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -147,11 +147,7 @@ class IfrtComputationClient : public ComputationClient { bool HasSharding() const override { return false; } - xla::OpSharding GetSharding() const override { - XLA_ERROR() << "GetSharding should not be called on IfrtData, check " - "HasSharding first"; - return xla::OpSharding(); - } + xla::OpSharding GetSharding() const override; std::string ToString() const override { std::stringstream ss; From 835487f189bc494f58b1f02bf8e3cac156daeb18 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 12 Oct 2023 17:01:20 +0000 Subject: [PATCH 05/33] Add `xla::OpSharding` back as source of truth --- .../csrc/runtime/ifrt_computation_client.cc | 11 ++----- .../csrc/runtime/ifrt_computation_client.h | 32 +++++++++++++------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 0e06c34e44f..240341b3b04 100644 --- 
a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -277,7 +277,6 @@ std::vector IfrtComputationClient::TransferToServer( ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( absl::Span tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { - // TODO: completely ignoring OpSharding. Is that important? // XLA_ERROR() << __FUNCTION__ << " not implemented"; // tsl::profiler::TraceMe activity( // "IfrtComputationClient::TransferShardsToServer", @@ -288,18 +287,13 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( // // directly copies buffers between devices using ICI, it can be much faster // // than transferring from the host to each device. auto data_shards = TransferToServer(tensor_shards); - // std::vector> pjrt_data_shards; std::vector> arrays; std::vector shard_shapes; for (auto& shard : data_shards) { auto ifrt_shard = std::dynamic_pointer_cast(shard); arrays.push_back(ifrt_shard->buffer); shard_shapes.push_back(ifrt_shard->buffer->shape()); - // pjrt_data_shards.push_back(std::make_shared( - // pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer)); } - // return std::make_shared(device, shape, pjrt_data_shards, - // sharding); xla::ifrt::Shape ifrt_shape(shape.dimensions()); xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); std::unique_ptr ifrt_sharding = xla::ifrt::ConcreteSharding::Create( @@ -308,8 +302,7 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( ifrt_shape, shard_shapes ); - // TODO: why doesn't HloSharding work? 
- // RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Only SingleDeviceSharding, OpaqueSharding, ConcreteSharding, ConcreteEvenSharding, and ShardingParamSharding are supported: sharding=HloSharding(memory_kind: (default), hlo_sharding: {devices=[1,4]0,1,2,3}) + // TODO: Attach HloSharding instead when it is supported // std::unique_ptr ifrt_sharding = xla::ifrt::HloSharding::Create( // devices_list, // xla::ifrt::MemoryKind(), @@ -317,7 +310,7 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( // ); tsl::RCReference sharded_array = client_->AssembleArrayFromSingleDeviceArrays( ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); - return std::make_shared(device, shape, sharded_array); + return std::make_shared(device, shape, sharded_array, sharding); } ComputationClient::DataPtr IfrtComputationClient::CopyToDevice( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index f3767ca4615..ab0fc198981 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -124,8 +124,9 @@ class IfrtComputationClient : public ComputationClient { : Data(std::move(device), std::move(device_shape)) {} IfrtData(std::string device, xla::Shape device_shape, - tsl::RCReference buffer) - : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} + tsl::RCReference buffer, + std::optional sharding = std::nullopt) + : Data(std::move(device), std::move(device_shape)), buffer(buffer), sharding_(sharding) {} IfrtData(std::string device, tsl::RCReference buffer) : Data(std::move(device), @@ -145,24 +146,35 @@ class IfrtComputationClient : public ComputationClient { return buffer != nullptr; // TODO: && !buffer->IsDeleted(); }; - bool HasSharding() const override { return false; } + bool HasSharding() const override { return sharding_.has_value(); } xla::OpSharding 
GetSharding() const override; std::string ToString() const override { std::stringstream ss; - ss << "XLAData: \n"; - ss << " Data Device: " << device() << "\n"; - ss << " Data Shape: " << shape().ToString() << "\n"; - ss << " Data Handle: "; - if (HasValue()) { - ss << reinterpret_cast(buffer.get()) << "\n"; + + if (HasSharding()) { + ss << "XLAShardedData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " OpSharding: " + << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n"; + ss << " NumShards: " << buffer->sharding().devices().size() << "\n"; } else { - ss << "None\n"; + ss << "XLAData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " Data Handle: "; + if (HasValue()) { + ss << reinterpret_cast(buffer.get()) << "\n"; + } else { + ss << "None\n"; + } } return ss.str(); } + std::optional sharding_; tsl::RCReference buffer; }; From a821a4d17e9eb20f0d58eaa60f08911d44bf190b Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 12 Oct 2023 22:33:04 +0000 Subject: [PATCH 06/33] wrapping and unwrapping sharded data --- .../csrc/runtime/ifrt_computation_client.cc | 115 +++++++++++------- 1 file changed, 72 insertions(+), 43 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 240341b3b04..f095c3c25d9 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -154,12 +154,8 @@ void IfrtComputationClient::IfrtData::Assign( } xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { - // XLA_ERROR() << "GetSharding should not be called on IfrtData, check " - // "HasSharding first"; - const xla::ifrt::Sharding& sharding = buffer->sharding(); - auto hlo_sharding = dynamic_cast(sharding); - // TODO: why are we using the proto? 
- return hlo_sharding.xla_hlo_sharding().ToProto(); + XLA_CHECK(HasSharding()) << "Check HasSharding first"; + return *sharding_; } ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder( @@ -169,27 +165,28 @@ ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder( std::vector IfrtComputationClient::GetDataShards( ComputationClient::DataPtr data) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - // tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShards", - // tsl::profiler::TraceMeLevel::kInfo); - // std::vector shards; - // if (PjRtShardedData* sharded_data = - // dynamic_cast(data.get())) { - // for (auto shard : sharded_data->shards) { - // shards.push_back(std::make_shared( - // shard->device(), shard->shape(), shard->buffer)); - // } - // } else { - // shards.push_back(data); - // } - // return shards; + tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShards", + tsl::profiler::TraceMeLevel::kInfo); + std::vector shards; + if (data->HasSharding()) { + auto ifrt_data = std::dynamic_pointer_cast(data); + std::vector> arrays = ifrt_data->buffer->DisassembleIntoSingleDeviceArrays(xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + + for (auto array : arrays) { + shards.push_back(std::make_shared( + PjRtDeviceToString(array->sharding().devices()[0]), array)); + } + } else { + shards.push_back(data); + } + return shards; } ComputationClient::DataPtr IfrtComputationClient::GetDataShard( ComputationClient::DataPtr data, size_t index) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - // tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShard", - // tsl::profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShard", + tsl::profiler::TraceMeLevel::kInfo); + return GetDataShards(data)[index]; // if (PjRtShardedData* sharded_data = // dynamic_cast(data.get())) { // XLA_CHECK_LE(index, sharded_data->shards.size()) @@ -206,7 +203,7 @@ 
ComputationClient::DataPtr IfrtComputationClient::GetDataShard( ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( const std::vector& shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + // XLA_ERROR() << __FUNCTION__ << " not implemented"; // std::vector> pjrt_data_shards; // pjrt_data_shards.reserve(shards.size()); // for (auto& shard : shards) { @@ -217,15 +214,42 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( // } // return std::make_shared(device, shape, pjrt_data_shards, // sharding); + // TODO: implement CreateDataPlaceholder for sharded data + if (shards.size() == 0) { + TF_LOG(WARNING) << "creating sharded placeholder"; + return std::make_shared(device, shape, tsl::RCReference(), sharding); + } + std::vector> arrays; + std::vector shard_shapes; + for (auto& shard : shards) { + auto ifrt_shard = std::dynamic_pointer_cast(shard); + arrays.push_back(ifrt_shard->buffer); + shard_shapes.push_back(ifrt_shard->buffer->shape()); + } + xla::ifrt::Shape ifrt_shape(shape.dimensions()); + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); + XLA_CHECK_EQ(shard_shapes.size(), devices_list.size()); + std::unique_ptr ifrt_sharding = xla::ifrt::ConcreteSharding::Create( + devices_list, + xla::ifrt::MemoryKind(), + ifrt_shape, + shard_shapes + ); + // TODO: Attach HloSharding instead when it is supported + // std::unique_ptr ifrt_sharding = xla::ifrt::HloSharding::Create( + // devices_list, + // xla::ifrt::MemoryKind(), + // xla::HloSharding::FromProto(sharding).value() + // ); + tsl::RCReference sharded_array = client_->AssembleArrayFromSingleDeviceArrays( + ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + return std::make_shared(device, shape, sharded_array, sharding); } std::optional IfrtComputationClient::GetDataSharding( 
DataPtr handle) { - return std::nullopt; - // if (auto sharded_data = dynamic_cast(handle.get())) { - // return sharded_data->GetSharding(); - // } - // return std::optional(); + auto ifrt_data = std::dynamic_pointer_cast(handle); + return ifrt_data->sharding_; } std::vector IfrtComputationClient::TransferToServer( @@ -424,6 +448,10 @@ std::vector IfrtComputationClient::TransferFromServer( // auto new_handle = ReplicateShardedData(handle); auto pjrt_data = std::dynamic_pointer_cast(handle); + // TODO: this is probably wrong for MP + auto replicated_array = pjrt_data->buffer->FullyReplicatedShard( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + // TODO: handle dynamic shapes auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape())); @@ -431,7 +459,7 @@ std::vector IfrtComputationClient::TransferFromServer( XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), absl::MakeSpan(byte_strides))); XLA_CHECK_OK( - pjrt_data->buffer + replicated_array ->CopyToHostBuffer(literal.untyped_data(), byte_strides, xla::ifrt::ArrayCopySemantics::kAlwaysCopy) .Await()); @@ -617,22 +645,23 @@ IfrtComputationClient::ExecuteReplicated( const std::vector>& arguments, absl::Span devices, const ExecuteReplicatedOptions& options) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + // XLA_ERROR() << __FUNCTION__ << " not implemented"; // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteReplicated` and the async work in `Execute` are // complete; a copy is held from the lambda that releases it when done. 
- // auto timed = - // std::make_shared(ExecuteReplicatedMetric()); - // tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated", - // tsl::profiler::TraceMeLevel::kInfo); - // const PjRtComputation& pjrt_computation = - // dynamic_cast(computation); - // XLA_CHECK(devices.size() == arguments.size()) - // << "ExecuteReplicated over " << devices.size() << " devices, but " - // << arguments.size() << " arguments devices."; - // auto mwait_argument = std::make_shared(devices.size()); - // std::vector> - // argument_handles(devices.size()); + auto timed = + std::make_shared(ExecuteReplicatedMetric()); + tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated", + tsl::profiler::TraceMeLevel::kInfo); + const IfrtComputation& ifrt_computation = + dynamic_cast(computation); + XLA_CHECK(devices.size() == arguments.size()) + << "ExecuteReplicated over " << devices.size() << " devices, but " + << arguments.size() << " arguments devices."; + auto mwait_argument = std::make_shared(devices.size()); + std::vector> argument_handles; + + // TODO: parallelize again if necessary // { // tsl::profiler::TraceMe activity( // "IfrtComputationClient::ExecuteReplicated_argument_handle", From 2e8842bae10085260c023e1a16ebbea9804abd84 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 12 Oct 2023 23:38:06 +0000 Subject: [PATCH 07/33] ExecuteReplicated --- .../csrc/runtime/ifrt_computation_client.cc | 51 ++++++++++++++++--- .../csrc/runtime/ifrt_computation_client.h | 4 +- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index f095c3c25d9..acd98f6aada 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -643,25 +643,64 @@ std::vector> IfrtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, const std::vector>& arguments, + // TODO: 
devices isn't doing anything helpful here absl::Span devices, const ExecuteReplicatedOptions& options) { // XLA_ERROR() << __FUNCTION__ << " not implemented"; // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteReplicated` and the async work in `Execute` are // complete; a copy is held from the lambda that releases it when done. + // TODO: fix timing auto timed = std::make_shared(ExecuteReplicatedMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated", tsl::profiler::TraceMeLevel::kInfo); const IfrtComputation& ifrt_computation = dynamic_cast(computation); - XLA_CHECK(devices.size() == arguments.size()) - << "ExecuteReplicated over " << devices.size() << " devices, but " - << arguments.size() << " arguments devices."; - auto mwait_argument = std::make_shared(devices.size()); - std::vector> argument_handles; - + // XLA_CHECK(devices.size() == arguments.size()) + // << "ExecuteReplicated over " << devices.size() << " devices, but " + // << arguments.size() << " arguments devices."; // TODO: parallelize again if necessary + std::vector> argument_handles(arguments[0].size()); + for (int32_t i = 0; i < arguments[0].size(); ++i) { + auto ifrt_data = std::dynamic_pointer_cast(arguments[0][i]); + argument_handles[i] = ifrt_data->buffer; + } + + xla::ExecuteOptions execute_options; + execute_options.untuple_result = options.explode_tuple; + execute_options.strict_shape_checking = true; + // TODO(yeounoh) currently only support single-slice execution + execute_options.multi_slice_config = nullptr; + + xla::ifrt::LoadedExecutable::ExecuteResult result = + ifrt_computation.executable + ->Execute(absl::MakeSpan(argument_handles), execute_options, std::nullopt) + .value(); + + xla::ifrt::Future returned_future = result.status; + auto results = result.outputs; + + std::vector data_handles; + data_handles.reserve(results.size()); + + 
XLA_CHECK(ifrt_computation.executable->GetOutputShardings().has_value()); + auto output_shardings = *(ifrt_computation.executable->GetOutputShardings()); + XLA_CHECK_EQ(output_shardings.size(), results.size()); + + for (int32_t i = 0; i < results.size(); ++i) { + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); + XLA_CHECK(pjrt_device->IsAddressable()) + << pjrt_device->DebugString() << " is not addressable."; + + std::shared_ptr data = + std::make_shared(devices[i], results[i], output_shardings[i]); + data_handles.push_back(data); + } + + // TODO: any useful debug logging + return {data_handles}; + // { // tsl::profiler::TraceMe activity( // "IfrtComputationClient::ExecuteReplicated_argument_handle", diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index ab0fc198981..b2be4cd91dd 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -128,12 +128,12 @@ class IfrtComputationClient : public ComputationClient { std::optional sharding = std::nullopt) : Data(std::move(device), std::move(device_shape)), buffer(buffer), sharding_(sharding) {} - IfrtData(std::string device, tsl::RCReference buffer) + IfrtData(std::string device, tsl::RCReference buffer, std::optional sharding = std::nullopt) : Data(std::move(device), xla::ShapeUtil::MakeShape( xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), buffer->shape().dims())), - buffer(buffer) {} + buffer(buffer), sharding_(sharding) {} Handle GetHandle() override { XLA_CHECK(HasValue()) From d75c040bc6016026f67fe3af7afb1e9bb797164f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 18 Oct 2023 17:04:34 +0000 Subject: [PATCH 08/33] [revert later] try resharding --- .../csrc/runtime/ifrt_computation_client.cc | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc 
b/torch_xla/csrc/runtime/ifrt_computation_client.cc index acd98f6aada..62c03bb27cf 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -270,6 +270,8 @@ std::vector IfrtComputationClient::TransferToServer( absl::MakeSpan(byte_strides))); total_size += literal->size_bytes(); + std::cout << "transfer to " << tensor.device << " " << literal->ToString() << std::endl; + // Avoid use-after-free on `literal` due to unsequenced move and use. xla::Literal* literal_pointer = literal.get(); tsl::RCReference buffer = @@ -447,11 +449,32 @@ std::vector IfrtComputationClient::TransferFromServer( // is not sharded, then it is a no-op. // auto new_handle = ReplicateShardedData(handle); auto pjrt_data = std::dynamic_pointer_cast(handle); + std::cout << "sharded " << pjrt_data->buffer->shape().DebugString() << std::endl; // TODO: this is probably wrong for MP - auto replicated_array = pjrt_data->buffer->FullyReplicatedShard( - xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); - + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); + // auto replicated_array = pjrt_data->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + auto replicated_array = pjrt_data->buffer->Reshard( + xla::ifrt::ConcreteEvenSharding::Create( + pjrt_data->buffer->sharding().devices(), + xla::ifrt::MemoryKind(), + pjrt_data->buffer->shape(), + pjrt_data->buffer->shape() + ), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value()->DisassembleIntoSingleDeviceArrays(xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value()[0]; + std::cout << "replicated " << replicated_array->shape().DebugString() << std::endl; + + // ->Reshard( + // xla::ifrt::ConcreteEvenSharding::Create( + // devices_list, + // xla::ifrt::MemoryKind(), + // pjrt_data->buffer->shape(), + // pjrt_data->buffer->shape() + // ), + // xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value() + // 
auto& literal = literals.emplace_back(xla::ShapeUtil::MakeShape( + // xla::ifrt::ToPrimitiveType(replicated_array->dtype()).value(), + // replicated_array->shape().dims())); // TODO: handle dynamic shapes auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape())); @@ -465,6 +488,7 @@ std::vector IfrtComputationClient::TransferFromServer( .Await()); total_size += literal.size_bytes(); + std::cout << literal.ToString() << std::endl; } InboundDataMetric()->AddSample(total_size); @@ -688,11 +712,13 @@ IfrtComputationClient::ExecuteReplicated( auto output_shardings = *(ifrt_computation.executable->GetOutputShardings()); XLA_CHECK_EQ(output_shardings.size(), results.size()); + std::cout << "output" << std::endl; for (int32_t i = 0; i < results.size(); ++i) { xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString() << " is not addressable."; + std::cout << results[i]->sharding().DebugString() << std::endl; std::shared_ptr data = std::make_shared(devices[i], results[i], output_shardings[i]); data_handles.push_back(data); From 5544fc0f26fddf540bd0b31cd30cf7c6efd3b5ac Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 19 Oct 2023 16:23:37 +0000 Subject: [PATCH 09/33] Revert "[revert later] try resharding" This reverts commit 7d52f671f1ed28215dc363706ce45b401000d2c7. 
--- .../csrc/runtime/ifrt_computation_client.cc | 32 ++----------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 62c03bb27cf..acd98f6aada 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -270,8 +270,6 @@ std::vector IfrtComputationClient::TransferToServer( absl::MakeSpan(byte_strides))); total_size += literal->size_bytes(); - std::cout << "transfer to " << tensor.device << " " << literal->ToString() << std::endl; - // Avoid use-after-free on `literal` due to unsequenced move and use. xla::Literal* literal_pointer = literal.get(); tsl::RCReference buffer = @@ -449,32 +447,11 @@ std::vector IfrtComputationClient::TransferFromServer( // is not sharded, then it is a no-op. // auto new_handle = ReplicateShardedData(handle); auto pjrt_data = std::dynamic_pointer_cast(handle); - std::cout << "sharded " << pjrt_data->buffer->shape().DebugString() << std::endl; // TODO: this is probably wrong for MP - xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); - // auto replicated_array = pjrt_data->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); - auto replicated_array = pjrt_data->buffer->Reshard( - xla::ifrt::ConcreteEvenSharding::Create( - pjrt_data->buffer->sharding().devices(), - xla::ifrt::MemoryKind(), - pjrt_data->buffer->shape(), - pjrt_data->buffer->shape() - ), - xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value()->DisassembleIntoSingleDeviceArrays(xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value()[0]; - std::cout << "replicated " << replicated_array->shape().DebugString() << std::endl; - - // ->Reshard( - // xla::ifrt::ConcreteEvenSharding::Create( - // devices_list, - // xla::ifrt::MemoryKind(), - // pjrt_data->buffer->shape(), - // pjrt_data->buffer->shape() - // ), - // 
xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value() - // auto& literal = literals.emplace_back(xla::ShapeUtil::MakeShape( - // xla::ifrt::ToPrimitiveType(replicated_array->dtype()).value(), - // replicated_array->shape().dims())); + auto replicated_array = pjrt_data->buffer->FullyReplicatedShard( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + // TODO: handle dynamic shapes auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape())); @@ -488,7 +465,6 @@ std::vector IfrtComputationClient::TransferFromServer( .Await()); total_size += literal.size_bytes(); - std::cout << literal.ToString() << std::endl; } InboundDataMetric()->AddSample(total_size); @@ -712,13 +688,11 @@ IfrtComputationClient::ExecuteReplicated( auto output_shardings = *(ifrt_computation.executable->GetOutputShardings()); XLA_CHECK_EQ(output_shardings.size(), results.size()); - std::cout << "output" << std::endl; for (int32_t i = 0; i < results.size(); ++i) { xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString() << " is not addressable."; - std::cout << results[i]->sharding().DebugString() << std::endl; std::shared_ptr data = std::make_shared(devices[i], results[i], output_shardings[i]); data_handles.push_back(data); From 03481e023a6de1ddba1599a34eec6c11d1f70b0c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 19 Oct 2023 17:47:01 +0000 Subject: [PATCH 10/33] reassemble sharded outputs --- .../csrc/runtime/ifrt_computation_client.cc | 163 +++++++++--------- .../csrc/runtime/ifrt_computation_client.h | 4 +- 2 files changed, 88 insertions(+), 79 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index acd98f6aada..ae9fc5a9ed2 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -358,80 +358,87 @@ ComputationClient::DataPtr 
IfrtComputationClient::CopyToDevice( // std::move(status_or.value())); } -ComputationClient::DataPtr IfrtComputationClient::ReplicateShardedData( - const ComputationClient::DataPtr& handle) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - // if (PjRtShardedData* sharded_data = - // dynamic_cast(handle.get())) { - // XLA_COUNTER("ReplicateShardedData", 1); - // TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() - // << ", shape=" << handle->shape() << ")"; - // if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { - // // Data is replicated, return the first shard - // return sharded_data->shards[0]; - // } - // xla::XlaBuilder builder("ReplicateShardedData"); - // xla::Shape shape = sharded_data->shape(); - // builder.SetSharding(sharded_data->GetSharding()); - - // // perform a simple identity calculation to reassemble the input as - // // replicated output. - // xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0"); - // builder.SetSharding(xla::HloSharding::Replicate().ToProto()); - // xla::XlaOp scalar_zero_op = xla::ConvertElementType( - // xla::ConstantR0(&builder, 0), shape.element_type()); - // xla::XlaOp y = xla::Add(x, scalar_zero_op); - // auto instruction = XlaBuilderFriend::GetInstruction(y); - // *instruction->mutable_sharding() = - // xla::HloSharding::Replicate().ToProto(); - - // xla::XlaComputation computation = - // ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false)); - // xla::ProgramShape program_shape = - // ConsumeValue(computation.GetProgramShape()); - - // std::string device = GetDefaultDevice(); - // std::vector - // instances; - // instances.push_back({std::move(computation), device, - // GetCompilationDevices(device, {}), &shape, - // /*should_wrap_parameter=*/false, - // /*is_sharded=*/true, - // /*allow_spmd_sharding_propagation_to_output=*/false}); - // std::vector< - // std::shared_ptr> - // computations = Compile(std::move(instances)); - - // auto shards = 
sharded_data->shards; - // XLA_CHECK_EQ(shards.size(), GetLocalDevices().size()); - // auto device_index = build_index_map(GetLocalDevices()); - - // std::vector> arguments_by_device( - // GetLocalDevices().size(), - // std::vector(1)); - // for (auto shard : shards) { - // std::vector device_spec = - // absl::StrSplit(shard->device(), ':'); - // XLA_CHECK_EQ(device_spec.size(), 2) - // << "Invalid device specification: " << shard->device(); - // int device_i = device_index[std::stoi(device_spec[1])]; - // TF_VLOG(3) << shard->device() << " is mapped to local device index " - // << device_i; - // arguments_by_device[device_i][0] = shard; - // } - // torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions - // execute_options; - // auto sharded_results = - // ExecuteReplicated(*computations.front(), arguments_by_device, - // GetLocalDevices(), execute_options); - // XLA_CHECK(sharded_results.size() > 0) - // << "empty ExecuteReplicated results returned."; - // XLA_CHECK(sharded_results[0].size() == 1) - // << "Wrong number of outputs, expected: 1, actual: " - // << sharded_results[0].size(); - // return sharded_results[0][0]; +tsl::RCReference IfrtComputationClient::ReplicateShardedData( + const std::shared_ptr handle) { + + if (handle->buffer->sharding().devices().size() == 1) { + return handle->buffer; + } + + XLA_COUNTER("ReplicateShardedData", 1); + TF_VLOG(1) << "ReplicateShardedData (handle=" // TODO: why isn't GetHandle const? << handle->GetHandle() + << ", shape=" << handle->shape() << ")"; + // if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { + // // Data is replicated, return the first shard + // return sharded_data->shards[0]; + // } + xla::XlaBuilder builder("ReplicateShardedData"); + xla::Shape shape = handle->shape(); + builder.SetSharding(handle->GetSharding()); + + // perform a simple identity calculation to reassemble the input as + // replicated output. 
+ xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0"); + builder.SetSharding(xla::HloSharding::Replicate().ToProto()); + xla::XlaOp scalar_zero_op = xla::ConvertElementType( + xla::ConstantR0(&builder, 0), shape.element_type()); + xla::XlaOp y = xla::Add(x, scalar_zero_op); + auto instruction = XlaBuilderFriend::GetInstruction(y); + *instruction->mutable_sharding() = + xla::HloSharding::Replicate().ToProto(); + + xla::XlaComputation computation = + ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false)); + xla::ProgramShape program_shape = + ConsumeValue(computation.GetProgramShape()); + + std::string device = GetDefaultDevice(); + std::vector + instances; + instances.push_back({std::move(computation), device, + GetCompilationDevices(device, {}), &shape, + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/false}); + std::vector< + std::shared_ptr> + computations = Compile(std::move(instances)); + + // auto shards = sharded_data->shards; + XLA_CHECK_EQ(handle->buffer->sharding().devices().size(), GetLocalDevices().size()); + // auto device_index = build_index_map(GetLocalDevices()); + + // std::vector> arguments_by_device( + // GetLocalDevices().size(), + // std::vector(1)); + // for (auto shard : shards) { + // std::vector device_spec = + // absl::StrSplit(shard->device(), ':'); + // XLA_CHECK_EQ(device_spec.size(), 2) + // << "Invalid device specification: " << shard->device(); + // int device_i = device_index[std::stoi(device_spec[1])]; + // TF_VLOG(3) << shard->device() << " is mapped to local device index " + // << device_i; + // arguments_by_device[device_i][0] = shard; // } - // return handle; + torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions + execute_options; + + // TODO: fix const plumbing for real + DataPtr handle_but_not_const = std::make_shared(handle->device(), handle->buffer, handle->GetSharding()); + // std::vector> args; // TODO: figure out brace init = {{handle}}; + 
// args.push_back({}); + // args[0].push_back(handle_but_not_const); + auto sharded_results = + ExecuteReplicated(*computations.front(), {{handle_but_not_const}}, + GetLocalDevices(), execute_options); + auto replicated_output = std::dynamic_pointer_cast(sharded_results[0][0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); + // XLA_CHECK(sharded_results.size() > 0) + // << "empty ExecuteReplicated results returned."; + // XLA_CHECK(sharded_results[0].size() == 1) + // << "Wrong number of outputs, expected: 1, actual: " + // << sharded_results[0].size(); + return *replicated_output; } std::vector IfrtComputationClient::TransferFromServer( @@ -445,16 +452,16 @@ std::vector IfrtComputationClient::TransferFromServer( for (auto handle : handles) { // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. - // auto new_handle = ReplicateShardedData(handle); - auto pjrt_data = std::dynamic_pointer_cast(handle); + auto ifrt_data = std::dynamic_pointer_cast(handle); + tsl::RCReference replicated_array = ReplicateShardedData(ifrt_data); // TODO: this is probably wrong for MP - auto replicated_array = pjrt_data->buffer->FullyReplicatedShard( - xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + // auto replicated_array = pjrt_data->buffer->FullyReplicatedShard( + // xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); // TODO: handle dynamic shapes auto& literal = literals.emplace_back( - xla::ShapeUtil::DeviceShapeToHostShape(pjrt_data->shape())); + xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape())); std::vector byte_strides(literal.shape().dimensions_size()); XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), absl::MakeSpan(byte_strides))); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index b2be4cd91dd..f0ac786a9be 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ 
b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -42,7 +42,7 @@ class IfrtComputationClient : public ComputationClient { absl::Span tensors) override; // Use XLA replication to re-assemble the sharded data. - DataPtr ReplicateShardedData(const DataPtr& handle); + // DataPtr ReplicateShardedData(const DataPtr& handle); std::vector TransferFromServer( absl::Span handles) override; @@ -178,6 +178,8 @@ class IfrtComputationClient : public ComputationClient { tsl::RCReference buffer; }; + tsl::RCReference ReplicateShardedData( + const std::shared_ptr handle); // struct PjRtShardedData : public Data { // PjRtShardedData(std::string device, xla::Shape shape) = delete; From f2f8334fa3cef405c870d0aafbd0d63d33a51c48 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 19 Oct 2023 18:38:58 +0000 Subject: [PATCH 11/33] fix output devices --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index ae9fc5a9ed2..ec772b1b8b8 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -696,12 +696,13 @@ IfrtComputationClient::ExecuteReplicated( XLA_CHECK_EQ(output_shardings.size(), results.size()); for (int32_t i = 0; i < results.size(); ++i) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - XLA_CHECK(pjrt_device->IsAddressable()) - << pjrt_device->DebugString() << " is not addressable."; + // std::cout << "device str " << devices[i] << std::endl; + // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); + // XLA_CHECK(pjrt_device->IsAddressable()) + // << pjrt_device->DebugString() << " is not addressable."; std::shared_ptr data = - std::make_shared(devices[i], results[i], output_shardings[i]); + std::make_shared("SPMD:0", results[i], output_shardings[i]); data_handles.push_back(data); } From 
c16ea6274b4400541cbc244a6522d6b843316f1d Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 19 Oct 2023 23:24:49 +0000 Subject: [PATCH 12/33] cleanup --- .../csrc/runtime/ifrt_computation_client.cc | 360 +++++------------- 1 file changed, 86 insertions(+), 274 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index ec772b1b8b8..3e0c3962741 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -73,15 +73,6 @@ std::unordered_map build_index_map( return device_index; } -// Builds the xla::Shape of the output xla::Literal on the host. -// xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { -// xla::Shape shape = xla::ShapeUtil::MakeShape( -// buffer->element_type(), buffer->logical_dimensions().value()); -// *shape.mutable_layout() = buffer->layout(); - -// return xla::ShapeUtil::DeviceShapeToHostShape(shape); -// } - } // namespace std::string IfrtComputationClient::PjRtDeviceToString( @@ -187,36 +178,14 @@ ComputationClient::DataPtr IfrtComputationClient::GetDataShard( tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShard", tsl::profiler::TraceMeLevel::kInfo); return GetDataShards(data)[index]; - // if (PjRtShardedData* sharded_data = - // dynamic_cast(data.get())) { - // XLA_CHECK_LE(index, sharded_data->shards.size()) - // << "GetDataShard out of range with index: " << index - // << " and num of shard: " << sharded_data->shards.size(); - // std::shared_ptr shard = sharded_data->shards[index]; - // return std::make_shared(shard->device(), shard->shape(), - // shard->buffer); - // } else { - // return data; - // } } ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( const std::vector& shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { - // XLA_ERROR() << __FUNCTION__ << " not implemented"; - // std::vector> pjrt_data_shards; - // 
pjrt_data_shards.reserve(shards.size()); - // for (auto& shard : shards) { - // XLA_CHECK(shard != nullptr); - // auto pjrt_shard = dynamic_cast(shard.get()); - // pjrt_data_shards.push_back(std::make_shared( - // pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer)); - // } - // return std::make_shared(device, shape, pjrt_data_shards, - // sharding); // TODO: implement CreateDataPlaceholder for sharded data if (shards.size() == 0) { - TF_LOG(WARNING) << "creating sharded placeholder"; + TF_LOG(INFO) << "creating sharded placeholder"; return std::make_shared(device, shape, tsl::RCReference(), sharding); } std::vector> arrays; @@ -301,15 +270,14 @@ std::vector IfrtComputationClient::TransferToServer( ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( absl::Span tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { - // XLA_ERROR() << __FUNCTION__ << " not implemented"; - // tsl::profiler::TraceMe activity( - // "IfrtComputationClient::TransferShardsToServer", - // tsl::profiler::TraceMeLevel::kInfo); - // // TODO(jonbolin): Consider using CopyToDevice when sharding is REPLICATED. - // // We are opting out of CopyToDevice for now due to the synchronization - // // issues observed in ShardingUtil::InputHandler, but because CopyToDevice - // // directly copies buffers between devices using ICI, it can be much faster - // // than transferring from the host to each device. + tsl::profiler::TraceMe activity( + "IfrtComputationClient::TransferShardsToServer", + tsl::profiler::TraceMeLevel::kInfo); + // TODO(jonbolin): Consider using CopyToDevice when sharding is REPLICATED. + // We are opting out of CopyToDevice for now due to the synchronization + // issues observed in ShardingUtil::InputHandler, but because CopyToDevice + // directly copies buffers between devices using ICI, it can be much faster + // than transferring from the host to each device. 
auto data_shards = TransferToServer(tensor_shards); std::vector> arrays; std::vector shard_shapes; @@ -340,22 +308,6 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( ComputationClient::DataPtr IfrtComputationClient::CopyToDevice( ComputationClient::DataPtr data, std::string dst) { XLA_ERROR() << __FUNCTION__ << " not implemented"; - // tsl::profiler::TraceMe activity("IfrtComputationClient::CopyToDevice", - // tsl::profiler::TraceMeLevel::kInfo); - // const PjRtData* pjrt_data = dynamic_cast(data.get()); - // XLA_CHECK(pjrt_data->HasValue()) << "Can't copy invalid device data."; - - // xla::PjRtDevice* dst_device = StringToPjRtDevice(dst); - // XLA_CHECK(dst_device->IsAddressable()) << dst << "is not addressable."; - - // // Returns error if the buffer is already on `dst_device`. - // xla::StatusOr> status_or = - // pjrt_data->buffer->CopyToDevice(dst_device); - // XLA_CHECK(status_or.ok()) - // << pjrt_data->device() << " buffer already exists on " << dst; - - // return std::make_shared(dst, pjrt_data->shape(), - // std::move(status_or.value())); } tsl::RCReference IfrtComputationClient::ReplicateShardedData( @@ -366,8 +318,9 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( } XLA_COUNTER("ReplicateShardedData", 1); - TF_VLOG(1) << "ReplicateShardedData (handle=" // TODO: why isn't GetHandle const? 
<< handle->GetHandle() + TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() << ", shape=" << handle->shape() << ")"; + // TODO: handle replicated data // if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { // // Data is replicated, return the first shard // return sharded_data->shards[0]; @@ -404,40 +357,18 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( std::shared_ptr> computations = Compile(std::move(instances)); - // auto shards = sharded_data->shards; XLA_CHECK_EQ(handle->buffer->sharding().devices().size(), GetLocalDevices().size()); - // auto device_index = build_index_map(GetLocalDevices()); - - // std::vector> arguments_by_device( - // GetLocalDevices().size(), - // std::vector(1)); - // for (auto shard : shards) { - // std::vector device_spec = - // absl::StrSplit(shard->device(), ':'); - // XLA_CHECK_EQ(device_spec.size(), 2) - // << "Invalid device specification: " << shard->device(); - // int device_i = device_index[std::stoi(device_spec[1])]; - // TF_VLOG(3) << shard->device() << " is mapped to local device index " - // << device_i; - // arguments_by_device[device_i][0] = shard; - // } + torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; // TODO: fix const plumbing for real DataPtr handle_but_not_const = std::make_shared(handle->device(), handle->buffer, handle->GetSharding()); - // std::vector> args; // TODO: figure out brace init = {{handle}}; - // args.push_back({}); - // args[0].push_back(handle_but_not_const); auto sharded_results = ExecuteReplicated(*computations.front(), {{handle_but_not_const}}, GetLocalDevices(), execute_options); auto replicated_output = std::dynamic_pointer_cast(sharded_results[0][0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); - // XLA_CHECK(sharded_results.size() > 0) - // << "empty ExecuteReplicated results returned."; - // XLA_CHECK(sharded_results[0].size() == 1) - // << "Wrong number of outputs, 
expected: 1, actual: " - // << sharded_results[0].size(); + // TODO: sanity check outputs return *replicated_output; } @@ -455,10 +386,6 @@ std::vector IfrtComputationClient::TransferFromServer( auto ifrt_data = std::dynamic_pointer_cast(handle); tsl::RCReference replicated_array = ReplicateShardedData(ifrt_data); - // TODO: this is probably wrong for MP - // auto replicated_array = pjrt_data->buffer->FullyReplicatedShard( - // xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); - // TODO: handle dynamic shapes auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape())); @@ -563,87 +490,89 @@ IfrtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { - // Shared ownership of the timed section ensures that it will only get logged - // once both `ExecuteComputation` and the async work in `ExecuteSharded` are - // complete; a copy is held from the lambda that releases it when done. - auto timed = std::make_shared(ExecuteMetric()); - tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteComputation", - tsl::profiler::TraceMeLevel::kInfo); - TF_VLOG(1) << "Executing Ifrt computation on " << device; - const IfrtComputation& pjrt_computation = - dynamic_cast(computation); + // TODO: Implement sharded exec in IFRT + XLA_ERROR() << __FUNCTION__ << " not implemented"; + // // Shared ownership of the timed section ensures that it will only get logged + // // once both `ExecuteComputation` and the async work in `ExecuteSharded` are + // // complete; a copy is held from the lambda that releases it when done. 
+ // auto timed = std::make_shared(ExecuteMetric()); + // tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteComputation", + // tsl::profiler::TraceMeLevel::kInfo); + // TF_VLOG(1) << "Executing Ifrt computation on " << device; + // const IfrtComputation& pjrt_computation = + // dynamic_cast(computation); + + // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device); + // XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + + // std::vector> buffers; + // buffers.reserve(arguments.size()); + // for (auto& argument : arguments) { + // const IfrtData* pjrt_data = dynamic_cast(argument.get()); + + // // XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) + // // << pjrt_device->DebugString() << " vs " + // // << pjrt_data->buffer->device()->DebugString(); + // buffers.push_back(pjrt_data->buffer); + // } - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device); - XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + // xla::ExecuteOptions execute_options; + // execute_options.untuple_result = options.explode_tuple; + // execute_options.strict_shape_checking = false; - std::vector> buffers; - buffers.reserve(arguments.size()); - for (auto& argument : arguments) { - const IfrtData* pjrt_data = dynamic_cast(argument.get()); + // // Required as of cl/518733871 + // execute_options.use_major_to_minor_data_layout_for_callbacks = true; - // XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) - // << pjrt_device->DebugString() << " vs " - // << pjrt_data->buffer->device()->DebugString(); - buffers.push_back(pjrt_data->buffer); - } + // xla::ifrt::DeviceList device_list({pjrt_device}); + // xla::ifrt::LoadedExecutable::ExecuteResult result = + // pjrt_computation.executable + // ->Execute(absl::MakeSpan(buffers), execute_options, device_list) + // .value(); - xla::ExecuteOptions execute_options; - execute_options.untuple_result = options.explode_tuple; - execute_options.strict_shape_checking = false; - - // Required 
as of cl/518733871 - execute_options.use_major_to_minor_data_layout_for_callbacks = true; + // xla::ifrt::Future returned_future = result.status; - xla::ifrt::DeviceList device_list({pjrt_device}); - xla::ifrt::LoadedExecutable::ExecuteResult result = - pjrt_computation.executable - ->Execute(absl::MakeSpan(buffers), execute_options, device_list) - .value(); + // auto results = result.outputs; + // std::vector datas; + // datas.reserve(results.size()); + // for (auto& result : results) { + // tsl::RCReference buffer = std::move(result); - xla::ifrt::Future returned_future = result.status; + // std::shared_ptr data = + // std::make_shared(device, std::move(buffer)); - auto results = result.outputs; - std::vector datas; - datas.reserve(results.size()); - for (auto& result : results) { - tsl::RCReference buffer = std::move(result); + // datas.push_back(data); + // } + // CreateDataHandlesCounter()->AddValue(datas.size()); - std::shared_ptr data = - std::make_shared(device, std::move(buffer)); + // auto mwait = std::make_shared(1); + // auto lockfn = [&, this, device, returned_future = std::move(returned_future), + // timed]() mutable { + // TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " + // << device; + // // Grab the shared lock and block the `WaitDeviceOps` until buffer is + // // ready. + // // TODO(JackCaoG): This lock should acquired outside of the lockfn and + // // passed in. It is possible that lockfn started after ExecuteComputation + // // released the xla_graph_executor lock, which will create a short windows + // // where device is unlcoked while execution is still running. + // auto lock = lock_device_shared(device); + // TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device + // << " Done"; + // // Signal that `ExecuteSharded` has completed for the ExecuteTime + // // metric. Copies the `timed` shared pointer into the lambda. 
+ // XLA_CHECK(returned_future.IsValid()) + // << "returned_future in ExecuteComputation is empty"; + // returned_future.OnReady( + // [timed, lock = std::move(lock)](xla::Status unused) mutable { + // timed.reset(); + // TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; + // }); + // }; - datas.push_back(data); - } - CreateDataHandlesCounter()->AddValue(datas.size()); + // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); - auto mwait = std::make_shared(1); - auto lockfn = [&, this, device, returned_future = std::move(returned_future), - timed]() mutable { - TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " - << device; - // Grab the shared lock and block the `WaitDeviceOps` until buffer is - // ready. - // TODO(JackCaoG): This lock should acquired outside of the lockfn and - // passed in. It is possible that lockfn started after ExecuteComputation - // released the xla_graph_executor lock, which will create a short windows - // where device is unlcoked while execution is still running. - auto lock = lock_device_shared(device); - TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device - << " Done"; - // Signal that `ExecuteSharded` has completed for the ExecuteTime - // metric. Copies the `timed` shared pointer into the lambda. 
- XLA_CHECK(returned_future.IsValid()) - << "returned_future in ExecuteComputation is empty"; - returned_future.OnReady( - [timed, lock = std::move(lock)](xla::Status unused) mutable { - timed.reset(); - TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; - }); - }; - - env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); - - TF_VLOG(1) << "Returning " << datas.size() << " results"; - return datas; + // TF_VLOG(1) << "Returning " << datas.size() << " results"; + // return datas; } std::vector> @@ -696,11 +625,6 @@ IfrtComputationClient::ExecuteReplicated( XLA_CHECK_EQ(output_shardings.size(), results.size()); for (int32_t i = 0; i < results.size(); ++i) { - // std::cout << "device str " << devices[i] << std::endl; - // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - // XLA_CHECK(pjrt_device->IsAddressable()) - // << pjrt_device->DebugString() << " is not addressable."; - std::shared_ptr data = std::make_shared("SPMD:0", results[i], output_shardings[i]); data_handles.push_back(data); @@ -708,118 +632,6 @@ IfrtComputationClient::ExecuteReplicated( // TODO: any useful debug logging return {data_handles}; - - // { - // tsl::profiler::TraceMe activity( - // "IfrtComputationClient::ExecuteReplicated_argument_handle", - // tsl::profiler::TraceMeLevel::kInfo); - // for (int32_t i = 0; i < devices.size(); ++i) { - // auto buffer_converter = [&, i]() { - // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - // XLA_CHECK(pjrt_device->IsAddressable()) << - // pjrt_device->DebugString(); - - // std::vector buffers; - // for (auto& argument : arguments[i]) { - // const PjRtData* pjrt_data = - // dynamic_cast(argument.get()); - - // XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) - // << pjrt_device->DebugString() << " vs " - // << pjrt_data->buffer->device()->DebugString(); - // buffers.push_back(pjrt_data->buffer.get()); - // } - // argument_handles[i] = std::move(buffers); - // }; - // 
env::ScheduleIoClosure(util::MultiWait::Completer( - // mwait_argument, std::move(buffer_converter))); - // } - // mwait_argument->Wait(); - // } - - // xla::ExecuteOptions execute_options; - // execute_options.untuple_result = options.explode_tuple; - // execute_options.strict_shape_checking = true; - // // TODO(yeounoh) currently only support single-slice execution - // execute_options.multi_slice_config = nullptr; - - // // Required as of cl/518733871 - // execute_options.use_major_to_minor_data_layout_for_callbacks = true; - - // std::optional>> returned_futures( - // devices.size()); - // std::vector>> results; - // { - // tsl::profiler::TraceMe activity( - // "IfrtComputationClient::ExecuteReplicated_execute", - // tsl::profiler::TraceMeLevel::kInfo); - // results = pjrt_computation.executable - // ->Execute(std::move(argument_handles), execute_options, - // returned_futures) - // .value(); - // } - - // std::vector> data_handles; - // data_handles.reserve(results.size()); - // std::vector dims(results.size()); - - // { - // tsl::profiler::TraceMe activity( - // "IfrtComputationClient::ExecuteReplicated_result_handle", - // tsl::profiler::TraceMeLevel::kInfo); - // for (int32_t i = 0; i < results.size(); ++i) { - // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - // XLA_CHECK(pjrt_device->IsAddressable()) - // << pjrt_device->DebugString() << " is not addressable."; - - // std::vector datas; - // datas.reserve(results[i].size()); - // dims[i] = results[i].size(); - // for (int32_t j = 0; j < results[i].size(); ++j) { - // std::unique_ptr buffer = std::move(results[i][j]); - // XLA_CHECK(pjrt_device == buffer->device()) - // << "Exepcted device: " << pjrt_device->DebugString() - // << " vs. 
actual device: " << buffer->device()->DebugString(); - - // std::shared_ptr data = - // std::make_shared(devices[i], std::move(buffer)); - // datas.push_back(data); - // } - // data_handles.push_back(datas); - // } - // } - - // auto mwait = std::make_shared(1); - // auto lockfn = [&, this, returned_futures = std::move(*returned_futures), - // timed]() mutable { - // // Grab the shared lock and block the `WaitDeviceOps` until buffer is - // // ready. Since this is the SPMD code path. There is no points to grab - // // devices lock for every individual device. - // TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " - // << spmd_device_str; - // auto lock = lock_device_shared(spmd_device_str); - // TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " - // << spmd_device_str << " Done"; - // // Signal that `ExecuteReplicated` has completed for one of the devices - // // the ExecuteReplicatedTime metric. Here, we assume that all devices - // // will finish execution roughly at the same time, hence only use one of - // // the returned_futures. Copies the `timed` shared pointer into the - // // lambda. 
- // XLA_CHECK(returned_futures[0].IsValid()) - // << "returned_future in ExecuteReplicated is empty"; - // returned_futures[0].OnReady( - // [timed, lock = std::move(lock)](xla::Status unused) mutable { - // timed.reset(); - // TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady - // finished"; - // }); - // }; - // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, - // std::move(lockfn))); - - // TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " - // << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; - // return data_handles; } size_t IfrtComputationClient::GetNumDevices() const { From 4d8418a41dfb0d04fd6ea1d22338cd58782bd159 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 28 Nov 2023 19:24:32 +0000 Subject: [PATCH 13/33] fix rebase issues --- torch_xla/csrc/runtime/BUILD | 31 ++-- .../csrc/runtime/ifrt_computation_client.cc | 138 +++++++----------- .../csrc/runtime/ifrt_computation_client.h | 35 +++-- torch_xla/csrc/runtime/initialize_pjrt.cc | 133 +++++++++++++++++ torch_xla/csrc/runtime/initialize_pjrt.h | 15 ++ .../csrc/runtime/pjrt_computation_client.cc | 112 +------------- 6 files changed, 247 insertions(+), 217 deletions(-) create mode 100644 torch_xla/csrc/runtime/initialize_pjrt.cc create mode 100644 torch_xla/csrc/runtime/initialize_pjrt.h diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 62d043308b0..998ba13e30e 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -83,19 +83,15 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", - ":multi_wait", + ":initialize_pjrt", + ":operation_manager", ":stablehlo_helper", ":tf_logging", - ":thread_pool", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla/client:xla_computation", "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - "@xla//xla/service:gpu_plugin", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - 
"@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", "@tsl//tsl/profiler/lib:traceme", @@ -118,6 +114,7 @@ cc_library( ":debug_macros", ":env_hash", ":env_vars", + ":initialize_pjrt", ":operation_manager", ":profiler", ":stablehlo_helper", @@ -129,11 +126,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/client:xla_computation", "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - "@xla//xla/service:gpu_plugin", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", @@ -200,6 +193,24 @@ cc_test( ], ) +cc_library( + name = "initialize_pjrt", + srcs = ["initialize_pjrt.cc"], + hdrs = ["initialize_pjrt.h"], + deps = [ + ":debug_macros", + ":env_vars", + ":profiler", + ":sys_util", + ":tf_logging", + ":xla_coordinator", + "@xla//xla/service:gpu_plugin", + "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", + "@xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@xla//xla/pjrt:pjrt_c_api_client", + ], +) + cc_library( name = "metrics_analysis", srcs = ["metrics_analysis.cc"], diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 3e0c3962741..121ee80df1b 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -9,22 +9,19 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/multi_wait.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "tsl/profiler/lib/traceme.h" #include 
"xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/distributed/distributed.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" @@ -98,26 +95,7 @@ std::vector IfrtComputationClient::PjRtDevicesToString( IfrtComputationClient::IfrtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - if (device_type == "CPU") { - TF_VLOG(1) << "Initializing PjRt CPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); - int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = xla::ifrt::PjRtClient::Create( - std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value())); - } else if (device_type == "TPU" || device_type == "TPU_C_API") { - TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); - tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); - XLA_CHECK(tpu_status.ok()); - client_ = xla::ifrt::PjRtClient::Create( - std::move(xla::GetCApiClient("TPU").value())); - } else { - XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, - device_type); - } - - XLA_CHECK(client_.get() != nullptr); + client_ = xla::ifrt::PjRtClient::Create(std::move(InitializePjRt(device_type))); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. 
Order the @@ -130,10 +108,38 @@ IfrtComputationClient::IfrtComputationClient() { global_ordinals_[device->id()] = global_ordinals_.size(); std::string device_str = PjRtDeviceToString(device); string_to_device_.emplace(device_str, device); - device_locks_.emplace(device_str, std::make_unique()); } - // manually create the device_locks for SPMD device - device_locks_.emplace(spmd_device_str, std::make_unique()); + + auto tracked_devices = GetLocalDevices(); + tracked_devices.emplace_back(spmd_device_str); + operation_manager_ = std::move(OperationManager(std::move(tracked_devices))); +} + +IfrtComputationClient::~IfrtComputationClient() { + // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient + // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. + client_ = nullptr; + coordinator_ = nullptr; +} + +bool IfrtComputationClient::CoordinatorInitialized() const { + return coordinator_ != nullptr; +} + +void IfrtComputationClient::InitializeCoordinator(int global_rank, + int world_size, + std::string master_addr, + std::string port) { + XLA_CHECK(coordinator_ == nullptr) + << "Can only initialize the XlaCoordinator once."; + coordinator_ = std::make_unique(global_rank, world_size, + master_addr, port); +} + +XlaCoordinator& IfrtComputationClient::GetCoordinator() { + XLA_CHECK(coordinator_ != nullptr) + << "XlaCoordinator has not been initialized"; + return *coordinator_; } void IfrtComputationClient::IfrtData::Assign( @@ -222,7 +228,7 @@ std::optional IfrtComputationClient::GetDataSharding( } std::vector IfrtComputationClient::TransferToServer( - absl::Span tensors) { + absl::Span> tensors) { metrics::TimedSection timed(TransferToServerMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -230,35 +236,27 @@ std::vector IfrtComputationClient::TransferToServer( datas.reserve(tensors.size()); int64_t total_size = 0; for (auto& tensor : tensors) { - 
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); - auto literal = std::make_shared(tensor.shape); - tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes()); - std::vector byte_strides(literal->shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(), - absl::MakeSpan(byte_strides))); - total_size += literal->size_bytes(); + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); - // Avoid use-after-free on `literal` due to unsequenced move and use. - xla::Literal* literal_pointer = literal.get(); tsl::RCReference buffer = client_ ->MakeArrayFromHostBuffer( - literal_pointer->untyped_data(), - xla::ifrt::ToDType(literal_pointer->shape().element_type()) - .value(), - xla::ifrt::Shape(literal_pointer->shape().dimensions()), - byte_strides, + tensor->data(), + xla::ifrt::ToDType(tensor->primitive_type()).value(), + xla::ifrt::Shape(tensor->dimensions()), + tensor->byte_strides(), // TODO: what is MemoryKind? 
xla::ifrt::SingleDeviceSharding::Create( pjrt_device, xla::ifrt::MemoryKind()), xla::PjRtClient::HostBufferSemantics:: kImmutableUntilTransferCompletes, - [literal{std::move(literal)}]() { /* frees literal */ }) + [tensor]() { /* frees tensor */ }) .value(); ComputationClient::DataPtr data = - std::make_shared(tensor.device, tensor.shape, buffer); + std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); } OutboundDataMetric()->AddSample(total_size); @@ -268,8 +266,8 @@ std::vector IfrtComputationClient::TransferToServer( } ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) { + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "IfrtComputationClient::TransferShardsToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -367,7 +365,7 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( auto sharded_results = ExecuteReplicated(*computations.front(), {{handle_but_not_const}}, GetLocalDevices(), execute_options); - auto replicated_output = std::dynamic_pointer_cast(sharded_results[0][0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); + auto replicated_output = std::dynamic_pointer_cast(sharded_results[0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); // TODO: sanity check outputs return *replicated_output; } @@ -575,10 +573,10 @@ IfrtComputationClient::ExecuteComputation( // return datas; } -std::vector> +std::vector IfrtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, - const std::vector>& arguments, + const absl::Span arguments, // TODO: devices isn't doing anything helpful here absl::Span devices, const ExecuteReplicatedOptions& options) { @@ -597,9 +595,9 @@ IfrtComputationClient::ExecuteReplicated( // << "ExecuteReplicated over " << 
devices.size() << " devices, but " // << arguments.size() << " arguments devices."; // TODO: parallelize again if necessary - std::vector> argument_handles(arguments[0].size()); - for (int32_t i = 0; i < arguments[0].size(); ++i) { - auto ifrt_data = std::dynamic_pointer_cast(arguments[0][i]); + std::vector> argument_handles(arguments.size()); + for (int32_t i = 0; i < arguments.size(); ++i) { + auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); argument_handles[i] = ifrt_data->buffer; } @@ -683,37 +681,11 @@ xla::PjRtDevice* IfrtComputationClient::StringToPjRtDevice( return pjrt_device; } -std::shared_lock IfrtComputationClient::lock_device_shared( - const std::string& device) { - std::shared_lock lock(*device_locks_[device]); - return lock; -} - -std::unique_lock IfrtComputationClient::lock_device( - const std::string& device) { - std::unique_lock lock(*device_locks_[device]); - return lock; -} - void IfrtComputationClient::WaitDeviceOps( - const std::vector& devices) { - std::unordered_set wait_devices; - if (!devices.empty()) { - for (auto& device_str : devices) { - wait_devices.insert(device_str); - } - } else { - for (auto& device_str : GetLocalDevices()) { - wait_devices.insert(device_str); - } - } - for (const std::string& device_str : wait_devices) { - TF_VLOG(3) << "Waiting for device execution for " << device_str - << " to finish"; - lock_device(device_str); - TF_VLOG(3) << "Waiting for device execution for " << device_str - << " to finish.. Done"; - } + absl::Span devices) { + TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", "); + operation_manager_.WaitForDevices(devices.empty() ? 
GetLocalDevices() + : devices); } std::map IfrtComputationClient::GetMetrics() const { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index f0ac786a9be..e7d5b8900c5 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -10,6 +10,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" @@ -26,6 +27,7 @@ namespace runtime { class IfrtComputationClient : public ComputationClient { public: IfrtComputationClient(); + ~IfrtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; @@ -39,7 +41,7 @@ class IfrtComputationClient : public ComputationClient { std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToServer( - absl::Span tensors) override; + absl::Span> tensors) override; // Use XLA replication to re-assemble the sharded data. 
// DataPtr ReplicateShardedData(const DataPtr& handle); @@ -47,9 +49,9 @@ class IfrtComputationClient : public ComputationClient { std::vector TransferFromServer( absl::Span handles) override; - DataPtr TransferShardsToServer(absl::Span tensor_shards, - std::string device, xla::Shape shape, - xla::OpSharding sharding) override; + DataPtr TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -61,9 +63,9 @@ class IfrtComputationClient : public ComputationClient { const std::string& device, const ExecuteComputationOptions& options) override; - std::vector> ExecuteReplicated( + std::vector ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + const absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; @@ -88,12 +90,18 @@ class IfrtComputationClient : public ComputationClient { std::shared_ptr> GetReplicationDevices() override; - void PrepareToExit() override { return; }; - - void WaitDeviceOps(const std::vector& devices) override; + void WaitDeviceOps(absl::Span devices) override; std::map GetMetrics() const override; + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override; + + XlaCoordinator& GetCoordinator() override; + + bool CoordinatorInitialized() const override; + // NOT IMPLEMENTED MemoryInfo GetMemoryInfo(const std::string& device) override { @@ -102,18 +110,17 @@ class IfrtComputationClient : public ComputationClient { private: std::shared_ptr client_; + std::unique_ptr coordinator_; // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. 
std::unordered_map global_ordinals_; std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; - std::unordered_map> - device_locks_; + OperationManager operation_manager_; + tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( + tsl::Env::Default(), "ifrt", std::thread::hardware_concurrency()); xla::PjRtDevice* StringToPjRtDevice(const std::string& device); - std::shared_lock lock_device_shared( - const std::string& device); - std::unique_lock lock_device(const std::string& device); std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; std::vector PjRtDevicesToString( diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc new file mode 100644 index 00000000000..160b8ae21d4 --- /dev/null +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -0,0 +1,133 @@ +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/profiler.h" +#include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { + +namespace { + +xla::GpuAllocatorConfig GetGpuAllocatorConfig() { + auto allocator_config = xla::GpuAllocatorConfig{}; + if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { + return allocator_config; + } + if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { + allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; + } + allocator_config.preallocate = + 
sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); + allocator_config.memory_fraction = + sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); + return allocator_config; +} + +} + +std::unique_ptr InitializePjRt(const std::string& device_type) { + std::unique_ptr client; + + if (device_type == "CPU") { + TF_VLOG(1) << "Initializing PjRt CPU client..."; + bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); + int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); + client = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); + } else if (device_type == "TPU" || device_type == "TPU_C_API") { + TF_VLOG(1) << "Initializing TFRT TPU client..."; + // Prefer $TPU_LIBRARY_PATH if set + auto tpu_library_path = sys_util::GetEnvString( + env::kEnvTpuLibraryPath, + sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); + tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); + XLA_CHECK_OK(tpu_status); + client = std::move(xla::GetCApiClient("TPU").value()); + const PJRT_Api* c_api = + static_cast(client.get())->pjrt_c_api(); + profiler::RegisterProfilerForPlugin(c_api); + } else if (device_type == "TPU_LEGACY") { + XLA_ERROR() << "TPU_LEGACY client is no longer available."; + } else if (device_type == "GPU" || device_type == "CUDA" || + device_type == "ROCM") { + TF_VLOG(1) << "Initializing PjRt GPU client..."; + bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); + int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); + int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); + int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); + int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); + std::string master_addr = + runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); + std::string port = runtime::sys_util::GetEnvString( + 
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); + + xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; + xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; + auto allowed_devices = + std::make_optional>(std::set{local_process_rank}); + if (global_world_size > 1) { + // Use the XlaCoordinator as the distributed key-value store. + coordinator_ = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator_->GetClient(); + std::string key_prefix = "gpu:"; + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + }; + } + TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" + << global_process_rank << ", num_nodes=" << global_world_size; + xla::GpuClientOptions options; + options.allocator_config = GetGpuAllocatorConfig(); + options.node_id = global_process_rank; + options.num_nodes = global_world_size; + options.allowed_devices = allowed_devices; + options.platform_name = "gpu"; + options.should_stage_host_to_device_transfers = true; + options.kv_get = kv_get; + options.kv_put = kv_put; + client = std::move(xla::GetStreamExecutorGpuClient(options).value()); + } else if (device_type == "XPU") { + TF_VLOG(1) << "Initializing PjRt XPU client..."; + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) + .status()); + client = std::move(xla::GetCApiClient("XPU").value()); + } else if (device_type == "NEURON") { + TF_VLOG(1) << "Initializing PjRt NEURON client..."; + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( + env::kEnvNeuronLibraryPath, + "libneuronpjrt.so")) + 
.status()); + client = std::move(xla::GetCApiClient("NEURON").value()); + } else { + XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, + device_type); + } + + XLA_CHECK(client.get() != nullptr); + + return std::move(client); +} + +} +} diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/initialize_pjrt.h new file mode 100644 index 00000000000..395deac1182 --- /dev/null +++ b/torch_xla/csrc/runtime/initialize_pjrt.h @@ -0,0 +1,15 @@ +#ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ +#define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ + + +#include "xla/pjrt/pjrt_client.h" + +namespace torch_xla { +namespace runtime { + +std::unique_ptr InitializePjRt(const std::string& device_type); + +} +} + +#endif // XLA_CLIENT_INITIALIZE_PJRT_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 3796b3e41ab..22cbcfa5b57 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -13,6 +13,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_hash.h" #include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" @@ -25,14 +26,8 @@ #include "xla/client/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/distributed/distributed.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape.h" using xla::internal::XlaBuilderFriend; @@ -68,23 +63,6 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { return 
xla::ShapeUtil::DeviceShapeToHostShape(shape); } -xla::GpuAllocatorConfig GetGpuAllocatorConfig() { - auto allocator_config = xla::GpuAllocatorConfig{}; - if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { - return allocator_config; - } - if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { - allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; - } - allocator_config.preallocate = - sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); - allocator_config.memory_fraction = - sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); - return allocator_config; -} - torch::lazy::hash_t hash_comp_env( std::shared_ptr client, std::vector& ordered_devices) { @@ -142,93 +120,7 @@ std::vector PjRtComputationClient::PjRtDevicesToString( PjRtComputationClient::PjRtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - if (device_type == "CPU") { - TF_VLOG(1) << "Initializing PjRt CPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); - int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); - } else if (device_type == "TPU" || device_type == "TPU_C_API") { - TF_VLOG(1) << "Initializing TFRT TPU client..."; - // Prefer $TPU_LIBRARY_PATH if set - auto tpu_library_path = sys_util::GetEnvString( - env::kEnvTpuLibraryPath, - sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); - tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); - XLA_CHECK_OK(tpu_status); - client_ = std::move(xla::GetCApiClient("TPU").value()); - const PJRT_Api* c_api = - static_cast(client_.get())->pjrt_c_api(); - 
profiler::RegisterProfilerForPlugin(c_api); - } else if (device_type == "TPU_LEGACY") { - XLA_ERROR() << "TPU_LEGACY client is no longer available."; - } else if (device_type == "GPU" || device_type == "CUDA" || - device_type == "ROCM") { - TF_VLOG(1) << "Initializing PjRt GPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); - int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); - int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); - int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); - std::string master_addr = - runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); - std::string port = runtime::sys_util::GetEnvString( - "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - - xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; - xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; - auto allowed_devices = - std::make_optional>(std::set{local_process_rank}); - if (global_world_size > 1) { - // Use the XlaCoordinator as the distributed key-value store. 
- coordinator_ = std::make_unique( - global_process_rank, global_world_size, master_addr, port); - std::shared_ptr distributed_client = - coordinator_->GetClient(); - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; - } - TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" - << global_process_rank << ", num_nodes=" << global_world_size; - xla::GpuClientOptions options; - options.allocator_config = GetGpuAllocatorConfig(); - options.node_id = global_process_rank; - options.num_nodes = global_world_size; - options.allowed_devices = allowed_devices; - options.platform_name = "gpu"; - options.should_stage_host_to_device_transfers = true; - options.kv_get = kv_get; - options.kv_put = kv_put; - client_ = std::move(xla::GetStreamExecutorGpuClient(options).value()); - } else if (device_type == "XPU") { - TF_VLOG(1) << "Initializing PjRt XPU client..."; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin( - "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) - .status()); - client_ = std::move(xla::GetCApiClient("XPU").value()); - } else if (device_type == "NEURON") { - TF_VLOG(1) << "Initializing PjRt NEURON client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( - env::kEnvNeuronLibraryPath, - "libneuronpjrt.so")) - .status()); - client_ = std::move(xla::GetCApiClient("NEURON").value()); - } else { - XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, - device_type); - } - - XLA_CHECK(client_.get() != nullptr); + client_ = std::move(InitializePjRt(device_type)); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's 
global ordinal separately from its device ID. Order the From e5a13a28a975723e18dc87d016a6124075adaf00 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 28 Nov 2023 19:55:31 +0000 Subject: [PATCH 14/33] fix const plumbing --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 121ee80df1b..dc919e9af81 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -360,10 +360,8 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - // TODO: fix const plumbing for real - DataPtr handle_but_not_const = std::make_shared(handle->device(), handle->buffer, handle->GetSharding()); auto sharded_results = - ExecuteReplicated(*computations.front(), {{handle_but_not_const}}, + ExecuteReplicated(*computations.front(), {{handle}}, GetLocalDevices(), execute_options); auto replicated_output = std::dynamic_pointer_cast(sharded_results[0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); // TODO: sanity check outputs From 155a811f4b5dfe262f4e8b45a413ec875374ff87 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 28 Nov 2023 19:58:14 +0000 Subject: [PATCH 15/33] shared_ptr to unique_ptr --- torch_xla/csrc/runtime/pjrt_computation_client.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index a54b3c0b3b1..b43ffdecb45 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -110,7 +110,7 @@ class PjRtComputationClient : public ComputationClient { }; private: - std::shared_ptr client_; + std::unique_ptr client_; std::unique_ptr coordinator_; // 
global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. From b1238daade092e69e6baaabbdfc30287de6ec45e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 22:03:08 +0000 Subject: [PATCH 16/33] parallelize input/output handling --- .../csrc/runtime/ifrt_computation_client.cc | 44 +++++++++---------- .../csrc/runtime/ifrt_computation_client.h | 5 ++- .../csrc/runtime/pjrt_computation_client.cc | 2 +- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index dc919e9af81..8d2826478e1 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -37,7 +37,7 @@ namespace runtime { namespace { -static std::string spmd_device_str = "SPMD:0"; +static const std::string spmd_device_str = "SPMD:0"; // Initializes a distributed runtime client if dist_service_addr is specified std::shared_ptr @@ -589,15 +589,15 @@ IfrtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); const IfrtComputation& ifrt_computation = dynamic_cast(computation); - // XLA_CHECK(devices.size() == arguments.size()) - // << "ExecuteReplicated over " << devices.size() << " devices, but " - // << arguments.size() << " arguments devices."; - // TODO: parallelize again if necessary + std::vector> argument_handles(arguments.size()); - for (int32_t i = 0; i < arguments.size(); ++i) { - auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); - argument_handles[i] = ifrt_data->buffer; - } + pool_.ParallelFor(arguments.size(), 10000, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); + argument_handles[i] = ifrt_data->buffer; + } + }); xla::ExecuteOptions execute_options; execute_options.untuple_result = options.explode_tuple; @@ -613,21 +613,21 @@ IfrtComputationClient::ExecuteReplicated( 
xla::ifrt::Future returned_future = result.status; auto results = result.outputs; - std::vector data_handles; - data_handles.reserve(results.size()); - - XLA_CHECK(ifrt_computation.executable->GetOutputShardings().has_value()); - auto output_shardings = *(ifrt_computation.executable->GetOutputShardings()); + XLA_CHECK(ifrt_computation.output_shardings_.has_value()); + auto& output_shardings = *(ifrt_computation.output_shardings_); XLA_CHECK_EQ(output_shardings.size(), results.size()); - for (int32_t i = 0; i < results.size(); ++i) { - std::shared_ptr data = - std::make_shared("SPMD:0", results[i], output_shardings[i]); - data_handles.push_back(data); - } - - // TODO: any useful debug logging - return {data_handles}; + std::vector data_handles(results.size()); + pool_.ParallelFor(results.size(), 10000, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + data_handles[i] = + std::make_shared(spmd_device_str, results[i], output_shardings[i]); + } + }); + + TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; + return data_handles; } size_t IfrtComputationClient::GetNumDevices() const { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index e7d5b8900c5..091a5ee4920 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -247,9 +247,12 @@ class IfrtComputationClient : public ComputationClient { std::vector devices, std::unique_ptr executable) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) {} + executable(std::move(executable)) { + output_shardings_ = this->executable->GetOutputShardings(); + } std::unique_ptr executable; + std::optional> output_shardings_; }; }; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 22cbcfa5b57..aeb43637742 100644 --- 
a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -37,7 +37,7 @@ namespace runtime { namespace { -static std::string spmd_device_str = "SPMD:0"; +static const std::string spmd_device_str = "SPMD:0"; // Builds a map from the device's global ordinal to its index in the `devices` // array. From 472b9d4360987b77b0c2dabcea6f0bfe687ff993 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 22:21:31 +0000 Subject: [PATCH 17/33] fix concurrency issues --- .../csrc/runtime/ifrt_computation_client.cc | 63 ++++++++++++------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 8d2826478e1..daa9ea8bb32 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -5,6 +5,7 @@ #include #include "absl/strings/ascii.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" @@ -578,11 +579,9 @@ IfrtComputationClient::ExecuteReplicated( // TODO: devices isn't doing anything helpful here absl::Span devices, const ExecuteReplicatedOptions& options) { - // XLA_ERROR() << __FUNCTION__ << " not implemented"; // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteReplicated` and the async work in `Execute` are // complete; a copy is held from the lambda that releases it when done. 
- // TODO: fix timing auto timed = std::make_shared(ExecuteReplicatedMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated", @@ -591,13 +590,18 @@ IfrtComputationClient::ExecuteReplicated( dynamic_cast(computation); std::vector> argument_handles(arguments.size()); - pool_.ParallelFor(arguments.size(), 10000, - [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); - argument_handles[i] = ifrt_data->buffer; - } - }); + { + absl::BlockingCounter counter(arguments.size()); + pool_.ParallelFor(arguments.size(), 10000, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); + argument_handles[i] = ifrt_data->buffer; + counter.DecrementCount(); + } + }); + counter.Wait(); + } xla::ExecuteOptions execute_options; execute_options.untuple_result = options.explode_tuple; @@ -605,26 +609,43 @@ IfrtComputationClient::ExecuteReplicated( // TODO(yeounoh) currently only support single-slice execution execute_options.multi_slice_config = nullptr; + TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for " + << spmd_device_str; + auto op_tracker = operation_manager_.StartOperation(spmd_device_str); + TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for " + << spmd_device_str << " Done"; + xla::ifrt::LoadedExecutable::ExecuteResult result = ifrt_computation.executable ->Execute(absl::MakeSpan(argument_handles), execute_options, std::nullopt) .value(); - xla::ifrt::Future returned_future = result.status; - auto results = result.outputs; + result.status.OnReady( + std::move([timed, op_tracker = std::move(op_tracker)]( + xla::Status status) mutable { + timed.reset(); + TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished with status " << status; + })); + + auto outputs = result.outputs; XLA_CHECK(ifrt_computation.output_shardings_.has_value()); auto& 
output_shardings = *(ifrt_computation.output_shardings_); - XLA_CHECK_EQ(output_shardings.size(), results.size()); - - std::vector data_handles(results.size()); - pool_.ParallelFor(results.size(), 10000, - [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - data_handles[i] = - std::make_shared(spmd_device_str, results[i], output_shardings[i]); - } - }); + XLA_CHECK_EQ(output_shardings.size(), outputs.size()); + + std::vector data_handles(outputs.size()); + { + absl::BlockingCounter counter(outputs.size()); + pool_.ParallelFor(outputs.size(), 10000, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + data_handles[i] = + std::make_shared(spmd_device_str, outputs[i], output_shardings[i]); + counter.DecrementCount(); + } + }); + counter.Wait(); + } TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; return data_handles; From a6544b711e5889c6f98557031a620d2f69c6e1a1 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 22:24:28 +0000 Subject: [PATCH 18/33] remove some commented out code --- .../csrc/runtime/ifrt_computation_client.cc | 85 ------------------- .../csrc/runtime/ifrt_computation_client.h | 54 ------------ 2 files changed, 139 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index daa9ea8bb32..cea7d3c26dc 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -320,10 +320,6 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() << ", shape=" << handle->shape() << ")"; // TODO: handle replicated data - // if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { - // // Data is replicated, return the first shard - // return sharded_data->shards[0]; - // } xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = 
handle->shape(); builder.SetSharding(handle->GetSharding()); @@ -489,87 +485,6 @@ IfrtComputationClient::ExecuteComputation( const std::string& device, const ExecuteComputationOptions& options) { // TODO: Implement sharded exec in IFRT XLA_ERROR() << __FUNCTION__ << " not implemented"; - // // Shared ownership of the timed section ensures that it will only get logged - // // once both `ExecuteComputation` and the async work in `ExecuteSharded` are - // // complete; a copy is held from the lambda that releases it when done. - // auto timed = std::make_shared(ExecuteMetric()); - // tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteComputation", - // tsl::profiler::TraceMeLevel::kInfo); - // TF_VLOG(1) << "Executing Ifrt computation on " << device; - // const IfrtComputation& pjrt_computation = - // dynamic_cast(computation); - - // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device); - // XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); - - // std::vector> buffers; - // buffers.reserve(arguments.size()); - // for (auto& argument : arguments) { - // const IfrtData* pjrt_data = dynamic_cast(argument.get()); - - // // XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) - // // << pjrt_device->DebugString() << " vs " - // // << pjrt_data->buffer->device()->DebugString(); - // buffers.push_back(pjrt_data->buffer); - // } - - // xla::ExecuteOptions execute_options; - // execute_options.untuple_result = options.explode_tuple; - // execute_options.strict_shape_checking = false; - - // // Required as of cl/518733871 - // execute_options.use_major_to_minor_data_layout_for_callbacks = true; - - // xla::ifrt::DeviceList device_list({pjrt_device}); - // xla::ifrt::LoadedExecutable::ExecuteResult result = - // pjrt_computation.executable - // ->Execute(absl::MakeSpan(buffers), execute_options, device_list) - // .value(); - - // xla::ifrt::Future returned_future = result.status; - - // auto results = result.outputs; - // std::vector datas; 
- // datas.reserve(results.size()); - // for (auto& result : results) { - // tsl::RCReference buffer = std::move(result); - - // std::shared_ptr data = - // std::make_shared(device, std::move(buffer)); - - // datas.push_back(data); - // } - // CreateDataHandlesCounter()->AddValue(datas.size()); - - // auto mwait = std::make_shared(1); - // auto lockfn = [&, this, device, returned_future = std::move(returned_future), - // timed]() mutable { - // TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " - // << device; - // // Grab the shared lock and block the `WaitDeviceOps` until buffer is - // // ready. - // // TODO(JackCaoG): This lock should acquired outside of the lockfn and - // // passed in. It is possible that lockfn started after ExecuteComputation - // // released the xla_graph_executor lock, which will create a short windows - // // where device is unlcoked while execution is still running. - // auto lock = lock_device_shared(device); - // TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device - // << " Done"; - // // Signal that `ExecuteSharded` has completed for the ExecuteTime - // // metric. Copies the `timed` shared pointer into the lambda. 
- // XLA_CHECK(returned_future.IsValid()) - // << "returned_future in ExecuteComputation is empty"; - // returned_future.OnReady( - // [timed, lock = std::move(lock)](xla::Status unused) mutable { - // timed.reset(); - // TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; - // }); - // }; - - // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); - - // TF_VLOG(1) << "Returning " << datas.size() << " results"; - // return datas; } std::vector diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 091a5ee4920..468c1050783 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -187,60 +187,6 @@ class IfrtComputationClient : public ComputationClient { tsl::RCReference ReplicateShardedData( const std::shared_ptr handle); - // struct PjRtShardedData : public Data { - // PjRtShardedData(std::string device, xla::Shape shape) = delete; - - // PjRtShardedData(std::string device, xla::Shape shape, - // std::vector> shards, - // xla::OpSharding sharding) - // : Data(std::move(device), std::move(shape)), - // shards(shards), - // sharding(sharding) {} - - // Handle GetHandle() override { - // // Always returns `Handle` of the first shard. 
- // return shards[0]->GetHandle(); - // } - - // void Assign(const torch::lazy::BackendData& data) override { - // const PjRtShardedData& pjrt_sharded_data = - // dynamic_cast(data); - // if (&pjrt_sharded_data != this) { - // shards = std::move(pjrt_sharded_data.shards); - // } - // } - - // bool HasValue() const override { - // if (shards.empty()) { - // return false; - // } - - // for (auto& shard : shards) { - // if (!shard->HasValue()) { - // return false; - // } - // } - // return true; - // } - - // std::string ToString() const override { - // std::stringstream ss; - // ss << "XLAShardedData: \n"; - // ss << " Data Device: " << device() << "\n"; - // ss << " Data Shape: " << shape().ToString() << "\n"; - // ss << " OpSharding: " - // << xla::HloSharding::FromProto(sharding)->ToString() << "\n"; - // ss << " NumShards: " << shards.size() << "\n"; - // return ss.str(); - // } - - // bool HasSharding() const override { return true; } - - // xla::OpSharding GetSharding() const override { return sharding; } - - // std::vector> shards; - // xla::OpSharding sharding; - // }; struct IfrtComputation : public Computation { IfrtComputation(xla::XlaComputation computation, From 5f6b4afe63ef833de4086f9ed13340a5b4b22e73 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 1 Dec 2023 00:09:39 +0000 Subject: [PATCH 19/33] formatting --- .../csrc/runtime/ifrt_computation_client.cc | 156 ++++++++++-------- .../csrc/runtime/ifrt_computation_client.h | 15 +- torch_xla/csrc/runtime/initialize_pjrt.cc | 17 +- torch_xla/csrc/runtime/initialize_pjrt.h | 5 +- 4 files changed, 104 insertions(+), 89 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index cea7d3c26dc..5a6dac6c873 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -96,7 +96,8 @@ std::vector IfrtComputationClient::PjRtDevicesToString( 
IfrtComputationClient::IfrtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - client_ = xla::ifrt::PjRtClient::Create(std::move(InitializePjRt(device_type))); + client_ = + xla::ifrt::PjRtClient::Create(std::move(InitializePjRt(device_type))); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. Order the @@ -168,7 +169,11 @@ std::vector IfrtComputationClient::GetDataShards( std::vector shards; if (data->HasSharding()) { auto ifrt_data = std::dynamic_pointer_cast(data); - std::vector> arrays = ifrt_data->buffer->DisassembleIntoSingleDeviceArrays(xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + std::vector> arrays = + ifrt_data->buffer + ->DisassembleIntoSingleDeviceArrays( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value(); for (auto array : arrays) { shards.push_back(std::make_shared( @@ -193,7 +198,8 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( // TODO: implement CreateDataPlaceholder for sharded data if (shards.size() == 0) { TF_LOG(INFO) << "creating sharded placeholder"; - return std::make_shared(device, shape, tsl::RCReference(), sharding); + return std::make_shared( + device, shape, tsl::RCReference(), sharding); } std::vector> arrays; std::vector shard_shapes; @@ -203,22 +209,25 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( shard_shapes.push_back(ifrt_shard->buffer->shape()); } xla::ifrt::Shape ifrt_shape(shape.dimensions()); - xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), + client_->addressable_devices().end()}); XLA_CHECK_EQ(shard_shapes.size(), devices_list.size()); - std::unique_ptr ifrt_sharding = xla::ifrt::ConcreteSharding::Create( - devices_list, - xla::ifrt::MemoryKind(), - ifrt_shape, - shard_shapes - ); + std::unique_ptr 
ifrt_sharding = + xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(), + ifrt_shape, shard_shapes); // TODO: Attach HloSharding instead when it is supported - // std::unique_ptr ifrt_sharding = xla::ifrt::HloSharding::Create( + // std::unique_ptr ifrt_sharding = + // xla::ifrt::HloSharding::Create( // devices_list, // xla::ifrt::MemoryKind(), // xla::HloSharding::FromProto(sharding).value() // ); - tsl::RCReference sharded_array = client_->AssembleArrayFromSingleDeviceArrays( - ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + tsl::RCReference sharded_array = + client_ + ->AssembleArrayFromSingleDeviceArrays( + ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value(); return std::make_shared(device, shape, sharded_array, sharding); } @@ -229,7 +238,7 @@ std::optional IfrtComputationClient::GetDataSharding( } std::vector IfrtComputationClient::TransferToServer( - absl::Span> tensors) { + absl::Span> tensors) { metrics::TimedSection timed(TransferToServerMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -246,8 +255,7 @@ std::vector IfrtComputationClient::TransferToServer( ->MakeArrayFromHostBuffer( tensor->data(), xla::ifrt::ToDType(tensor->primitive_type()).value(), - xla::ifrt::Shape(tensor->dimensions()), - tensor->byte_strides(), + xla::ifrt::Shape(tensor->dimensions()), tensor->byte_strides(), // TODO: what is MemoryKind? 
xla::ifrt::SingleDeviceSharding::Create( pjrt_device, xla::ifrt::MemoryKind()), @@ -267,8 +275,8 @@ std::vector IfrtComputationClient::TransferToServer( } ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( - absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) { + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "IfrtComputationClient::TransferShardsToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -286,21 +294,24 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( shard_shapes.push_back(ifrt_shard->buffer->shape()); } xla::ifrt::Shape ifrt_shape(shape.dimensions()); - xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), client_->addressable_devices().end()}); - std::unique_ptr ifrt_sharding = xla::ifrt::ConcreteSharding::Create( - devices_list, - xla::ifrt::MemoryKind(), - ifrt_shape, - shard_shapes - ); + xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(), + client_->addressable_devices().end()}); + std::unique_ptr ifrt_sharding = + xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(), + ifrt_shape, shard_shapes); // TODO: Attach HloSharding instead when it is supported - // std::unique_ptr ifrt_sharding = xla::ifrt::HloSharding::Create( + // std::unique_ptr ifrt_sharding = + // xla::ifrt::HloSharding::Create( // devices_list, // xla::ifrt::MemoryKind(), // xla::HloSharding::FromProto(sharding).value() // ); - tsl::RCReference sharded_array = client_->AssembleArrayFromSingleDeviceArrays( - ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), xla::ifrt::ArrayCopySemantics::kAlwaysCopy).value(); + tsl::RCReference sharded_array = + client_ + ->AssembleArrayFromSingleDeviceArrays( + ifrt_shape, std::move(ifrt_sharding), absl::MakeSpan(arrays), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value(); return 
std::make_shared(device, shape, sharded_array, sharding); } @@ -311,14 +322,13 @@ ComputationClient::DataPtr IfrtComputationClient::CopyToDevice( tsl::RCReference IfrtComputationClient::ReplicateShardedData( const std::shared_ptr handle) { - if (handle->buffer->sharding().devices().size() == 1) { return handle->buffer; } XLA_COUNTER("ReplicateShardedData", 1); TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() - << ", shape=" << handle->shape() << ")"; + << ", shape=" << handle->shape() << ")"; // TODO: handle replicated data xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = handle->shape(); @@ -332,35 +342,35 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( xla::ConstantR0(&builder, 0), shape.element_type()); xla::XlaOp y = xla::Add(x, scalar_zero_op); auto instruction = XlaBuilderFriend::GetInstruction(y); - *instruction->mutable_sharding() = - xla::HloSharding::Replicate().ToProto(); + *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto(); xla::XlaComputation computation = ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false)); - xla::ProgramShape program_shape = - ConsumeValue(computation.GetProgramShape()); + xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); std::string device = GetDefaultDevice(); - std::vector - instances; + std::vector instances; instances.push_back({std::move(computation), device, - GetCompilationDevices(device, {}), &shape, - /*should_wrap_parameter=*/false, - /*is_sharded=*/true, - /*allow_spmd_sharding_propagation_to_output=*/false}); + GetCompilationDevices(device, {}), &shape, + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/false}); std::vector< std::shared_ptr> computations = Compile(std::move(instances)); - XLA_CHECK_EQ(handle->buffer->sharding().devices().size(), GetLocalDevices().size()); + XLA_CHECK_EQ(handle->buffer->sharding().devices().size(), + 
GetLocalDevices().size()); torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = - ExecuteReplicated(*computations.front(), {{handle}}, - GetLocalDevices(), execute_options); - auto replicated_output = std::dynamic_pointer_cast(sharded_results[0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); + auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}}, + GetLocalDevices(), execute_options); + auto replicated_output = + std::dynamic_pointer_cast(sharded_results[0]) + ->buffer->FullyReplicatedShard( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy); // TODO: sanity check outputs return *replicated_output; } @@ -377,7 +387,8 @@ std::vector IfrtComputationClient::TransferFromServer( // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. auto ifrt_data = std::dynamic_pointer_cast(handle); - tsl::RCReference replicated_array = ReplicateShardedData(ifrt_data); + tsl::RCReference replicated_array = + ReplicateShardedData(ifrt_data); // TODO: handle dynamic shapes auto& literal = literals.emplace_back( @@ -454,8 +465,8 @@ std::vector IfrtComputationClient::Compile( mlir::MLIRContext context; mlir::ModuleOp mlir_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - torch_xla::runtime::ConvertHloToStableHlo( - instance.computation.mutable_proto(), &mlir_module); + torch_xla::ConvertHloToStableHlo(instance.computation.mutable_proto(), + &mlir_module); std::unique_ptr executable = ConsumeValue(client_->GetDefaultCompiler()->Compile( std::make_unique(std::move(mlir_module)), @@ -504,17 +515,17 @@ IfrtComputationClient::ExecuteReplicated( const IfrtComputation& ifrt_computation = dynamic_cast(computation); - std::vector> argument_handles(arguments.size()); + std::vector> argument_handles( + arguments.size()); { absl::BlockingCounter counter(arguments.size()); - pool_.ParallelFor(arguments.size(), 10000, - [&](int64_t 
start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); - argument_handles[i] = ifrt_data->buffer; - counter.DecrementCount(); - } - }); + pool_.ParallelFor(arguments.size(), 10000, [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); + argument_handles[i] = ifrt_data->buffer; + counter.DecrementCount(); + } + }); counter.Wait(); } @@ -532,15 +543,17 @@ IfrtComputationClient::ExecuteReplicated( xla::ifrt::LoadedExecutable::ExecuteResult result = ifrt_computation.executable - ->Execute(absl::MakeSpan(argument_handles), execute_options, std::nullopt) + ->Execute(absl::MakeSpan(argument_handles), execute_options, + std::nullopt) .value(); - result.status.OnReady( - std::move([timed, op_tracker = std::move(op_tracker)]( - xla::Status status) mutable { - timed.reset(); - TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished with status " << status; - })); + result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)]( + xla::Status status) mutable { + timed.reset(); + TF_VLOG(3) + << "ExecuteReplicated returned_future->OnReady finished with status " + << status; + })); auto outputs = result.outputs; @@ -551,14 +564,13 @@ IfrtComputationClient::ExecuteReplicated( std::vector data_handles(outputs.size()); { absl::BlockingCounter counter(outputs.size()); - pool_.ParallelFor(outputs.size(), 10000, - [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - data_handles[i] = - std::make_shared(spmd_device_str, outputs[i], output_shardings[i]); - counter.DecrementCount(); - } - }); + pool_.ParallelFor(outputs.size(), 10000, [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + data_handles[i] = std::make_shared( + spmd_device_str, outputs[i], output_shardings[i]); + counter.DecrementCount(); + } + }); counter.Wait(); } diff --git 
a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 468c1050783..0ecd36ce387 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -64,8 +64,7 @@ class IfrtComputationClient : public ComputationClient { const ExecuteComputationOptions& options) override; std::vector ExecuteReplicated( - const Computation& computation, - const absl::Span arguments, + const Computation& computation, const absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; @@ -133,14 +132,18 @@ class IfrtComputationClient : public ComputationClient { IfrtData(std::string device, xla::Shape device_shape, tsl::RCReference buffer, std::optional sharding = std::nullopt) - : Data(std::move(device), std::move(device_shape)), buffer(buffer), sharding_(sharding) {} + : Data(std::move(device), std::move(device_shape)), + buffer(buffer), + sharding_(sharding) {} - IfrtData(std::string device, tsl::RCReference buffer, std::optional sharding = std::nullopt) + IfrtData(std::string device, tsl::RCReference buffer, + std::optional sharding = std::nullopt) : Data(std::move(device), xla::ShapeUtil::MakeShape( xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), buffer->shape().dims())), - buffer(buffer), sharding_(sharding) {} + buffer(buffer), + sharding_(sharding) {} Handle GetHandle() override { XLA_CHECK(HasValue()) @@ -165,7 +168,7 @@ class IfrtComputationClient : public ComputationClient { ss << " Data Device: " << device() << "\n"; ss << " Data Shape: " << shape().ToString() << "\n"; ss << " OpSharding: " - << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n"; + << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n"; ss << " NumShards: " << buffer->sharding().devices().size() << "\n"; } else { ss << "XLAData: \n"; diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 
160b8ae21d4..12175d0c6d2 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -2,14 +2,14 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" -#include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "torch_xla/csrc/runtime/tf_logging.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/tfrt_cpu_pjrt_client.h" namespace torch_xla { namespace runtime { @@ -33,9 +33,10 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() { return allocator_config; } -} +} // namespace -std::unique_ptr InitializePjRt(const std::string& device_type) { +std::unique_ptr InitializePjRt( + const std::string& device_type) { std::unique_ptr client; if (device_type == "CPU") { @@ -129,5 +130,5 @@ std::unique_ptr InitializePjRt(const std::string& device_type) return std::move(client); } -} -} +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/initialize_pjrt.h index 395deac1182..57515a5fe0d 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.h +++ b/torch_xla/csrc/runtime/initialize_pjrt.h @@ -1,7 +1,6 @@ #ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ #define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ - #include "xla/pjrt/pjrt_client.h" namespace torch_xla { @@ -10,6 +9,6 @@ namespace runtime { std::unique_ptr InitializePjRt(const std::string& device_type); } -} +} // namespace torch_xla -#endif // XLA_CLIENT_INITIALIZE_PJRT_H_ +#endif // XLA_CLIENT_INITIALIZE_PJRT_H_ From c0e1ce89a6196ec22578f74177c9f37746ef2468 Mon Sep 17 00:00:00 2001 From: Will Cromar 
Date: Fri, 1 Dec 2023 18:01:00 +0000 Subject: [PATCH 20/33] for coordinator init --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 6 ++++-- torch_xla/csrc/runtime/initialize_pjrt.cc | 11 ++++++----- torch_xla/csrc/runtime/initialize_pjrt.h | 3 ++- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 5a6dac6c873..0fd7fa9dade 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -96,8 +96,10 @@ std::vector IfrtComputationClient::PjRtDevicesToString( IfrtComputationClient::IfrtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - client_ = - xla::ifrt::PjRtClient::Create(std::move(InitializePjRt(device_type))); + std::unique_ptr pjrt_client; + std::tie(pjrt_client, coordinator_) = std::move(InitializePjRt(device_type)); + + client_ = xla::ifrt::PjRtClient::Create(std::move(pjrt_client)); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. 
Order the diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 12175d0c6d2..4e5f6ba7a1d 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -35,9 +35,10 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() { } // namespace -std::unique_ptr InitializePjRt( - const std::string& device_type) { +std::tuple, std::unique_ptr> +InitializePjRt(const std::string& device_type) { std::unique_ptr client; + std::unique_ptr coordinator; if (device_type == "CPU") { TF_VLOG(1) << "Initializing PjRt CPU client..."; @@ -78,10 +79,10 @@ std::unique_ptr InitializePjRt( std::make_optional>(std::set{local_process_rank}); if (global_world_size > 1) { // Use the XlaCoordinator as the distributed key-value store. - coordinator_ = std::make_unique( + coordinator = std::make_unique( global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = - coordinator_->GetClient(); + coordinator->GetClient(); std::string key_prefix = "gpu:"; kv_get = [distributed_client, key_prefix]( std::string_view k, @@ -127,7 +128,7 @@ std::unique_ptr InitializePjRt( XLA_CHECK(client.get() != nullptr); - return std::move(client); + return {std::move(client), std::move(coordinator)}; } } // namespace runtime diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/initialize_pjrt.h index 57515a5fe0d..012927fe474 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.h +++ b/torch_xla/csrc/runtime/initialize_pjrt.h @@ -6,7 +6,8 @@ namespace torch_xla { namespace runtime { -std::unique_ptr InitializePjRt(const std::string& device_type); +std::tuple, std::unique_ptr> +InitializePjRt(const std::string& device_type); } } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index aeb43637742..8d146c65fdf 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ 
b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -120,7 +120,7 @@ std::vector PjRtComputationClient::PjRtDevicesToString( PjRtComputationClient::PjRtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - client_ = std::move(InitializePjRt(device_type)); + std::tie(client_, coordinator_) = std::move(InitializePjRt(device_type)); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. Order the From 1812afa87c4e090294f38df36d11e2b7a0ae6723 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 1 Dec 2023 18:03:18 +0000 Subject: [PATCH 21/33] remove extra `std::move`s --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 0fd7fa9dade..e0df38e3753 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -97,7 +97,7 @@ std::vector IfrtComputationClient::PjRtDevicesToString( IfrtComputationClient::IfrtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); std::unique_ptr pjrt_client; - std::tie(pjrt_client, coordinator_) = std::move(InitializePjRt(device_type)); + std::tie(pjrt_client, coordinator_) = InitializePjRt(device_type); client_ = xla::ifrt::PjRtClient::Create(std::move(pjrt_client)); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 8d146c65fdf..f5fdec51e80 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -120,7 +120,7 @@ std::vector PjRtComputationClient::PjRtDevicesToString( PjRtComputationClient::PjRtComputationClient() { std::string device_type = 
sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - std::tie(client_, coordinator_) = std::move(InitializePjRt(device_type)); + std::tie(client_, coordinator_) = InitializePjRt(device_type); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. Order the From 135e68d83d318576193f5700571ce1759d704323 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 1 Dec 2023 18:32:38 +0000 Subject: [PATCH 22/33] unit test --- torch_xla/csrc/runtime/BUILD | 74 +++++++++++------ .../csrc/runtime/ifrt_computation_client.cc | 29 ++----- .../runtime/ifrt_computation_client_test.cc | 81 +++++++++++++++++++ 3 files changed, 138 insertions(+), 46 deletions(-) create mode 100644 torch_xla/csrc/runtime/ifrt_computation_client_test.cc diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 998ba13e30e..30f75680bec 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -456,28 +456,52 @@ ptxla_cc_test( ], ) -# disable for now since it is flaky on the upstream test. 
-# ptxla_cc_test( -# name = "pjrt_computation_client_test", -# srcs = ["pjrt_computation_client_test.cc"], -# deps = [ -# ":computation_client", -# ":pjrt_computation_client", -# ":tensor_source", -# "@xla//xla:literal", -# "@xla//xla:literal_util", -# "@xla//xla:shape_util", -# "@xla//xla:status", -# "@xla//xla:statusor", -# "@xla//xla/client:xla_builder", -# "@xla//xla/client:xla_computation", -# "@xla//xla/tests:literal_test_util", -# "@xla//xla/tools:hlo_module_loader", -# "@tsl//tsl/lib/core:status_test_util", -# "@tsl//tsl/platform:env", -# "@tsl//tsl/platform:errors", -# "@tsl//tsl/platform:logging", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_main", -# ], -# ) +ptxla_cc_test( + name = "pjrt_computation_client_test", + srcs = ["pjrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":pjrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +ptxla_cc_test( + name = "ifrt_computation_client_test", + srcs = ["ifrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":ifrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + 
"@tsl//tsl/platform:test_main", + ], +) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index e0df38e3753..9a8ae572cd9 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -445,22 +445,7 @@ std::vector IfrtComputationClient::Compile( compile_options.executable_build_options.set_device_assignment( device_assignment); } else { - // TODO(wcromar): set compile_options.argument_layouts, enable strict - // shapes - compile_options.executable_build_options.set_num_partitions(1); - compile_options.executable_build_options.set_num_replicas( - client_->device_count()); - compile_options.parameter_is_tupled_arguments = - instance.parameter_is_tupled_arguments; - - xla::DeviceAssignment device_assignment(client_->device_count(), 1); - // DeviceAssignment values must be the PjRtDevice ID, so we need to - // unwind the global ordinal mapping. - for (const auto& [device_id, global_ordinal] : global_ordinals_) { - device_assignment(global_ordinal, 0) = device_id; - } - compile_options.executable_build_options.set_device_assignment( - device_assignment); + XLA_ERROR() << "Only SPMD compilation is supported"; } // Convert HLO to StableHLO for Ifrt client compilation. 
@@ -476,14 +461,13 @@ std::vector IfrtComputationClient::Compile( StableHloCompileCounter()->AddValue(1); const auto& hlo_modules = ConsumeValue(executable->GetHloModules()); - xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation(); - std::shared_ptr pjrt_computation = + std::shared_ptr ifrt_computation = std::make_shared( std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), instance.devices, std::move(executable)); - computations.push_back(pjrt_computation); + computations.push_back(ifrt_computation); CreateCompileHandlesCounter()->AddValue(1); } @@ -559,8 +543,11 @@ IfrtComputationClient::ExecuteReplicated( auto outputs = result.outputs; - XLA_CHECK(ifrt_computation.output_shardings_.has_value()); - auto& output_shardings = *(ifrt_computation.output_shardings_); + const std::vector& output_shardings = + ifrt_computation.output_shardings_ + ? *ifrt_computation.output_shardings_ + : std::vector(outputs.size(), + xla::HloSharding::Replicate().ToProto()); XLA_CHECK_EQ(output_shardings.size(), outputs.size()); std::vector data_handles(outputs.size()); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cc b/torch_xla/csrc/runtime/ifrt_computation_client_test.cc new file mode 100644 index 00000000000..21194432dd4 --- /dev/null +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cc @@ -0,0 +1,81 @@ +#include "torch_xla/csrc/runtime/ifrt_computation_client.h" + +#include + +#include +#include +#include + +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/tensor_source.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/tests/literal_test_util.h" + +namespace torch_xla { +namespace runtime { + 
+tsl::StatusOr MakeComputation() { + xla::Shape input_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); + xla::XlaBuilder builder("AddComputation"); + xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x"); + xla::XlaOp y = xla::Parameter(&builder, 1, input_shape, "y"); + xla::XlaOp sum = xla::Add(x, y); + return builder.Build(); +} + +TEST(PjRtComputationClientTest, Init) { + // Get a CPU client. + tsl::setenv("PJRT_DEVICE", "CPU", true); + auto client = std::make_unique(); + std::string device = client->GetDefaultDevice(); + + // Compose a computation. + auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2}); + std::vector instances; + instances.push_back(ComputationClient::CompileInstance( + std::move(MakeComputation().value()), device, + client->GetCompilationDevices(device, client->GetLocalDevices()), &shape, + /*parameter_is_tupled_arguments=*/false, /*is_sharded=*/true)); + + // Prepare inputs. + xla::Literal literal_x = + xla::LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + xla::Literal literal_y = + xla::LiteralUtil::CreateR2({{5.0f, 6.0f}, {7.0f, 8.0f}}); + + // Compile the graph. + std::vector computations = + client->Compile(std::move(instances)); + + // Copy inputs to device. + ComputationClient::ExecuteReplicatedOptions options{}; + std::vector> args = { + std::make_shared(std::move(literal_x), device), + std::make_shared(std::move(literal_y), device)}; + + // Execute the graph. + std::vector results = client->ExecuteReplicated( + *computations[0], client->TransferToServer(absl::MakeConstSpan(args)), + {device}, options); + + // Copy the output from device back to host and assert correctness.
+ ASSERT_EQ(results.size(), 1); + auto result_literals = client->TransferFromServer(results); + ASSERT_THAT(result_literals, ::testing::SizeIs(1)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), + result_literals[0])); +} + +} // namespace runtime +} // namespace torch_xla From 8fce00562401c72d8621cb77f5c03718902e8e61 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 1 Dec 2023 18:52:19 +0000 Subject: [PATCH 23/33] tune parallelfors --- .../csrc/runtime/ifrt_computation_client.cc | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 9a8ae572cd9..862a20585d2 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -505,13 +505,19 @@ IfrtComputationClient::ExecuteReplicated( arguments.size()); { absl::BlockingCounter counter(arguments.size()); - pool_.ParallelFor(arguments.size(), 10000, [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); - argument_handles[i] = ifrt_data->buffer; - counter.DecrementCount(); - } - }); + + // Cost to handle one input argument. 
See tsl::ThreadPool::ParallelFor + // documentation + static const int32_t argument_handle_cost_ns = 1000; + pool_.ParallelFor(arguments.size(), argument_handle_cost_ns, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto ifrt_data = + std::dynamic_pointer_cast(arguments[i]); + argument_handles[i] = ifrt_data->buffer; + counter.DecrementCount(); + } + }); counter.Wait(); } @@ -553,13 +559,18 @@ IfrtComputationClient::ExecuteReplicated( std::vector data_handles(outputs.size()); { absl::BlockingCounter counter(outputs.size()); - pool_.ParallelFor(outputs.size(), 10000, [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - data_handles[i] = std::make_shared( - spmd_device_str, outputs[i], output_shardings[i]); - counter.DecrementCount(); - } - }); + + // Cost to handle one output. See tsl::ThreadPool::ParallelFor + // documentation. + static const int32_t result_handle_cost_ns = 2000; + pool_.ParallelFor(outputs.size(), result_handle_cost_ns, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + data_handles[i] = std::make_shared( + spmd_device_str, outputs[i], output_shardings[i]); + counter.DecrementCount(); + } + }); counter.Wait(); } From 86883ad3b368bfffcae591b5e2d80988ecc2f18a Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 21:19:05 +0000 Subject: [PATCH 24/33] pjrt -> ifrt --- .../csrc/runtime/ifrt_computation_client.cc | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 862a20585d2..8105b26b75b 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -73,8 +73,8 @@ std::unordered_map build_index_map( } // namespace -std::string IfrtComputationClient::PjRtDeviceToString( - xla::PjRtDevice* const device) const { +std::string 
IfrtComputationClient::IfrtDeviceToString( + xla::ifrt::Device* const device) const { std::string platform = absl::AsciiStrToUpper(device->client()->platform_name()); int ordinal = global_ordinals_.at(device->id()); @@ -82,13 +82,13 @@ std::string IfrtComputationClient::PjRtDeviceToString( return str; } -std::vector IfrtComputationClient::PjRtDevicesToString( - absl::Span devices) const { +std::vector IfrtComputationClient::IfrtDevicesToString( + absl::Span devices) const { std::vector strs; strs.reserve(devices.size()); for (auto* device : devices) { - strs.push_back(PjRtDeviceToString(device)); + strs.push_back(IfrtDeviceToString(device)); } return strs; @@ -104,7 +104,7 @@ IfrtComputationClient::IfrtComputationClient() { // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. Order the // devices by increasing ID to assign global ordinals. - std::vector ordered_devices(client_->device_count()); + std::vector ordered_devices(client_->device_count()); std::partial_sort_copy(client_->devices().begin(), client_->devices().end(), ordered_devices.begin(), ordered_devices.end(), [](auto& a, auto& b) { return a->id() < b->id(); }); @@ -148,9 +148,9 @@ XlaCoordinator& IfrtComputationClient::GetCoordinator() { void IfrtComputationClient::IfrtData::Assign( const torch::lazy::BackendData& data) { - const IfrtData& pjrt_data = dynamic_cast(data); - if (&pjrt_data != this) { - buffer = pjrt_data.buffer; + const IfrtData& ifrt_data = dynamic_cast(data); + if (&ifrt_data != this) { + buffer = ifrt_data.buffer; } } @@ -179,7 +179,7 @@ std::vector IfrtComputationClient::GetDataShards( for (auto array : arrays) { shards.push_back(std::make_shared( - PjRtDeviceToString(array->sharding().devices()[0]), array)); + IfrtDeviceToString(array->sharding().devices()[0]), array)); } } else { shards.push_back(data); @@ -248,7 +248,7 @@ std::vector IfrtComputationClient::TransferToServer( datas.reserve(tensors.size()); 
int64_t total_size = 0; for (auto& tensor : tensors) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); + xla::ifrt::Device* ifrt_device = StringToIfrtDevice(tensor->device()); total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); @@ -260,8 +260,8 @@ std::vector IfrtComputationClient::TransferToServer( xla::ifrt::Shape(tensor->dimensions()), tensor->byte_strides(), // TODO: what is MemoryKind? xla::ifrt::SingleDeviceSharding::Create( - pjrt_device, xla::ifrt::MemoryKind()), - xla::PjRtClient::HostBufferSemantics:: + ifrt_device, xla::ifrt::MemoryKind()), + xla::ifrt::Client::HostBufferSemantics:: kImmutableUntilTransferCompletes, [tensor]() { /* frees tensor */ }) .value(); @@ -424,8 +424,7 @@ std::vector IfrtComputationClient::Compile( // TODO(yeounoh) multi-host, multi-slice configurations compile_options.executable_build_options.set_use_spmd_partitioning(true); // We can override the compiler's default behavior to replicate the - // outputs. Setting this to true would wrapping the sharded outputs in - // PjRtShardedData. + // outputs. 
compile_options.executable_build_options .set_allow_spmd_sharding_propagation_to_output( {instance.allow_spmd_sharding_propagation_to_output}); @@ -583,15 +582,15 @@ size_t IfrtComputationClient::GetNumDevices() const { } std::string IfrtComputationClient::GetDefaultDevice() const { - return PjRtDeviceToString(client_->addressable_devices()[0]); + return IfrtDeviceToString(client_->addressable_devices()[0]); } std::vector IfrtComputationClient::GetLocalDevices() const { - return PjRtDevicesToString(client_->addressable_devices()); + return IfrtDevicesToString(client_->addressable_devices()); } std::vector IfrtComputationClient::GetAllDevices() const { - return PjRtDevicesToString(client_->devices()); + return IfrtDevicesToString(client_->devices()); } int IfrtComputationClient::GetNumProcesses() const { @@ -606,7 +605,7 @@ int IfrtComputationClient::GetNumProcesses() const { const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& IfrtComputationClient::GetDeviceAttributes(const std::string& device) { - return IfrtComputationClient::StringToPjRtDevice(device)->Attributes(); + return IfrtComputationClient::StringToIfrtDevice(device)->Attributes(); } void IfrtComputationClient::SetReplicationDevices( @@ -619,12 +618,12 @@ IfrtComputationClient::GetReplicationDevices() { return replication_devices_; } -xla::PjRtDevice* IfrtComputationClient::StringToPjRtDevice( +xla::ifrt::Device* IfrtComputationClient::StringToIfrtDevice( const std::string& device) { XLA_CHECK(string_to_device_.find(device) != string_to_device_.end()) << "Unknown device " << device; - xla::PjRtDevice* pjrt_device = string_to_device_[device]; - return pjrt_device; + xla::ifrt::Device* ifrt_device = string_to_device_[device]; + return ifrt_device; } void IfrtComputationClient::WaitDeviceOps( @@ -635,7 +634,7 @@ void IfrtComputationClient::WaitDeviceOps( } std::map IfrtComputationClient::GetMetrics() const { - // TODO(jonbolin): Add any PJRt-client-specific 
metrics here + // TODO(jonbolin): Add any Ifrt-client-specific metrics here return {}; } From 07a9efbff77a9c6448146b27672961ef140ba8f9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 21:51:29 +0000 Subject: [PATCH 25/33] fix rebasing issues --- .../csrc/runtime/ifrt_computation_client.cc | 2 +- .../csrc/runtime/ifrt_computation_client.h | 25 ++++++++++++++++--- torch_xla/csrc/runtime/runtime.cc | 1 - 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 8105b26b75b..df96bb6d602 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -110,7 +110,7 @@ IfrtComputationClient::IfrtComputationClient() { [](auto& a, auto& b) { return a->id() < b->id(); }); for (auto* device : ordered_devices) { global_ordinals_[device->id()] = global_ordinals_.size(); - std::string device_str = PjRtDeviceToString(device); + std::string device_str = IfrtDeviceToString(device); string_to_device_.emplace(device_str, device); } diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 0ecd36ce387..7186aeec46e 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -101,12 +101,28 @@ class IfrtComputationClient : public ComputationClient { bool CoordinatorInitialized() const override; + torch::lazy::hash_t HashCompilationEnv() override { + return comp_env_hash_; + } + // NOT IMPLEMENTED MemoryInfo GetMemoryInfo(const std::string& device) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; }; + std::string SerializeComputation( + const ComputationPtr computation) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + ComputationPtr DeserializeComputation( + const std::string& serialized) override { + XLA_ERROR() << __FUNCTION__ << " not 
implemented"; + } + + + private: std::shared_ptr client_; std::unique_ptr coordinator_; @@ -118,12 +134,13 @@ class IfrtComputationClient : public ComputationClient { OperationManager operation_manager_; tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( tsl::Env::Default(), "ifrt", std::thread::hardware_concurrency()); + torch::lazy::hash_t comp_env_hash_; - xla::PjRtDevice* StringToPjRtDevice(const std::string& device); + xla::ifrt::Device* StringToIfrtDevice(const std::string& device); - std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; - std::vector PjRtDevicesToString( - absl::Span devices) const; + std::string IfrtDeviceToString(xla::ifrt::Device* const device) const; + std::vector IfrtDevicesToString( + absl::Span devices) const; struct IfrtData : public Data { IfrtData(std::string device, xla::Shape device_shape) diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 4a4de61254c..feb2a0844c6 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -21,7 +21,6 @@ ComputationClient* GetComputationClient() { std::unique_ptr client; static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); - ComputationClient* client; if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { if (use_ifrt) { client = std::make_unique(); From 3bfaa3a1cb7ec0fd4bfb28679db16a2e55464713 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 21:54:24 +0000 Subject: [PATCH 26/33] formatting --- torch_xla/csrc/runtime/ifrt_computation_client.h | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 7186aeec46e..2716c05feec 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -101,9 +101,7 @@ class IfrtComputationClient : public ComputationClient { bool CoordinatorInitialized() const 
override; - torch::lazy::hash_t HashCompilationEnv() override { - return comp_env_hash_; - } + torch::lazy::hash_t HashCompilationEnv() override { return comp_env_hash_; } // NOT IMPLEMENTED @@ -111,8 +109,7 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; }; - std::string SerializeComputation( - const ComputationPtr computation) override { + std::string SerializeComputation(const ComputationPtr computation) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; } @@ -121,8 +118,6 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - - private: std::shared_ptr client_; std::unique_ptr coordinator_; From a71a9d588b14924c7d6b1c9c67e9459de088c991 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 22:08:07 +0000 Subject: [PATCH 27/33] fix timer and comp env hash --- torch_xla/csrc/runtime/BUILD | 1 + .../csrc/runtime/ifrt_computation_client.cc | 40 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 30f75680bec..664d26e7674 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -199,6 +199,7 @@ cc_library( hdrs = ["initialize_pjrt.h"], deps = [ ":debug_macros", + ":env_hash", ":env_vars", ":profiler", ":sys_util", diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index df96bb6d602..6bb71b1e59c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -9,6 +9,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/env_hash.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/initialize_pjrt.h" #include 
"torch_xla/csrc/runtime/stablehlo_helper.h" @@ -71,6 +72,40 @@ std::unordered_map build_index_map( return device_index; } +torch::lazy::hash_t hash_comp_env( + std::shared_ptr client, + std::vector& ordered_devices) { + torch::lazy::hash_t hash = hash::HashXlaEnvVars(); + // Whether or not SPMD mode is active should influence the hash. + hash = torch::lazy::HashCombine(hash, UseVirtualDevice()); + auto topology_desc = client->GetTopologyForDevices(ordered_devices); + if (topology_desc.ok()) { + // Some backends support a topology description which provides a better + // view of the specific compilation environment. + auto serialized = topology_desc.value()->Serialize(); + if (serialized.ok()) { + return torch::lazy::HashCombine( + hash, + torch::lazy::DataHash(serialized->data(), serialized->length())); + } + // If serialization fails, fallthrough to the manual approach. + } + std::string platform_name(client->platform_name()); + std::string platform_version(client->platform_version()); + hash = torch::lazy::HashCombine( + hash, torch::lazy::StringHash(platform_name.c_str())); + // platform_version incorporates libtpu version and hardware type. + hash = torch::lazy::HashCombine( + hash, torch::lazy::StringHash(platform_version.c_str())); + // Include global devices in the hash, ensuring order is consistent. 
+ for (auto& device : ordered_devices) { + std::string device_str(device->ToString()); + hash = torch::lazy::HashCombine( + hash, torch::lazy::StringHash(device_str.c_str())); + } + return hash; +} + } // namespace std::string IfrtComputationClient::IfrtDeviceToString( @@ -113,6 +148,7 @@ IfrtComputationClient::IfrtComputationClient() { std::string device_str = IfrtDeviceToString(device); string_to_device_.emplace(device_str, device); } + comp_env_hash_ = hash_comp_env(client_, ordered_devices); auto tracked_devices = GetLocalDevices(); tracked_devices.emplace_back(spmd_device_str); @@ -241,7 +277,7 @@ std::optional IfrtComputationClient::GetDataSharding( std::vector IfrtComputationClient::TransferToServer( absl::Span> tensors) { - metrics::TimedSection timed(TransferToServerMetric()); + auto timed = std::make_shared(TransferToServerMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); std::vector datas; @@ -263,7 +299,7 @@ std::vector IfrtComputationClient::TransferToServer( ifrt_device, xla::ifrt::MemoryKind()), xla::ifrt::Client::HostBufferSemantics:: kImmutableUntilTransferCompletes, - [tensor]() { /* frees tensor */ }) + [tensor, timed]() { /* frees tensor and timer */ }) .value(); ComputationClient::DataPtr data = From 0d2c842505f3b12ac700eb8741266d97763ad58e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 22:10:00 +0000 Subject: [PATCH 28/33] formatting --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 3 ++- torch_xla/csrc/runtime/ifrt_computation_client.h | 3 --- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 6bb71b1e59c..5c6c1f101d2 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -277,7 +277,8 @@ std::optional IfrtComputationClient::GetDataSharding( std::vector 
IfrtComputationClient::TransferToServer( absl::Span> tensors) { - auto timed = std::make_shared(TransferToServerMetric()); + auto timed = + std::make_shared(TransferToServerMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); std::vector datas; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 2716c05feec..1c7e44bd6ee 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -43,9 +43,6 @@ class IfrtComputationClient : public ComputationClient { std::vector TransferToServer( absl::Span> tensors) override; - // Use XLA replication to re-assemble the sharded data. - // DataPtr ReplicateShardedData(const DataPtr& handle); - std::vector TransferFromServer( absl::Span handles) override; From 5b26aa4bbd6caaee5515fc72126879703aac467e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 22:11:46 +0000 Subject: [PATCH 29/33] remove dead code --- .../csrc/runtime/ifrt_computation_client.cc | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 5c6c1f101d2..18ae0f2d997 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -41,22 +41,6 @@ namespace { static const std::string spmd_device_str = "SPMD:0"; -// Initializes a distributed runtime client if dist_service_addr is specified -std::shared_ptr -MaybeInitializeDistributedRuntimeClient(int local_rank, - std::string dist_service_addr) { - std::shared_ptr client; - if (!dist_service_addr.empty()) { - xla::DistributedRuntimeClient::Options options; - /* TODO(jonbolin): Use global rank for multi-host setup */ - options.node_id = local_rank; - client = xla::GetDistributedRuntimeClient(dist_service_addr, options); - 
XLA_CHECK(client->Connect().ok()) - << "Failed to initialize distributed runtime client"; - } - return std::move(client); -} - // Builds a map from the device's global ordinal to its index in the `devices` // array. std::unordered_map build_index_map( From 0df165ae34de3e525c82cc8d482944f9de2629c9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 22:18:43 +0000 Subject: [PATCH 30/33] Move SPMD string constant --- torch_xla/csrc/runtime/computation_client.h | 2 ++ torch_xla/csrc/runtime/ifrt_computation_client.cc | 2 -- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 6e87c13b193..97f633a39f4 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -375,6 +375,8 @@ class ComputationClient { static int64_t GetDeviceOrdinal(const std::string& device); protected: + static constexpr auto spmd_device_str = "SPMD:0"; + // Metrics common to all client interfaces. static metrics::Metric* TransferToServerMetric(); static metrics::Metric* TransferToServerTransformMetric(); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 18ae0f2d997..cc8f2d89eed 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -39,8 +39,6 @@ namespace runtime { namespace { -static const std::string spmd_device_str = "SPMD:0"; - // Builds a map from the device's global ordinal to its index in the `devices` // array. 
std::unordered_map build_index_map( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index f5fdec51e80..d12eb6dd092 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -37,8 +37,6 @@ namespace runtime { namespace { -static const std::string spmd_device_str = "SPMD:0"; - // Builds a map from the device's global ordinal to its index in the `devices` // array. std::unordered_map build_index_map( From 723191fe93534c509f7e9d74312c6daf87d620f4 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 22:51:18 +0000 Subject: [PATCH 31/33] fix compile error --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 4 ++-- torch_xla/csrc/runtime/pjrt_computation_client.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index cc8f2d89eed..add4c2fa1be 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -55,7 +55,7 @@ std::unordered_map build_index_map( } torch::lazy::hash_t hash_comp_env( - std::shared_ptr client, + xla::ifrt::Client* client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); // Whether or not SPMD mode is active should influence the hash. 
@@ -130,7 +130,7 @@ IfrtComputationClient::IfrtComputationClient() { std::string device_str = IfrtDeviceToString(device); string_to_device_.emplace(device_str, device); } - comp_env_hash_ = hash_comp_env(client_, ordered_devices); + comp_env_hash_ = hash_comp_env(client_.get(), ordered_devices); auto tracked_devices = GetLocalDevices(); tracked_devices.emplace_back(spmd_device_str); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index d12eb6dd092..d3fd7f3ce8a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -62,7 +62,7 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { } torch::lazy::hash_t hash_comp_env( - std::shared_ptr client, + xla::PjRtClient* client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); auto topology_desc = client->GetTopologyDescription(); @@ -132,7 +132,7 @@ PjRtComputationClient::PjRtComputationClient() { std::string device_str = PjRtDeviceToString(device); string_to_device_.emplace(device_str, device); } - comp_env_hash_ = hash_comp_env(client_, ordered_devices); + comp_env_hash_ = hash_comp_env(client_.get(), ordered_devices); auto tracked_devices = GetLocalDevices(); tracked_devices.emplace_back(spmd_device_str); From 9ebc616500760e2bc3869029c764e25bc9151ee6 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 13 Dec 2023 23:05:17 +0000 Subject: [PATCH 32/33] format --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index d3fd7f3ce8a..1ca27518282 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -62,8 +62,7 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { } torch::lazy::hash_t hash_comp_env( - xla::PjRtClient* 
client, - std::vector& ordered_devices) { + xla::PjRtClient* client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); auto topology_desc = client->GetTopologyDescription(); if (topology_desc.ok()) { From cc4c93d973edd580492a8990eedecf8ecec3ebec Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 14 Dec 2023 22:59:52 +0000 Subject: [PATCH 33/33] remove SPMD from hash --- torch_xla/csrc/runtime/ifrt_computation_client.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index add4c2fa1be..605826e6e6a 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -58,8 +58,6 @@ torch::lazy::hash_t hash_comp_env( xla::ifrt::Client* client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); - // Whether or not SPMD mode is active should influence the hash. - hash = torch::lazy::HashCombine(hash, UseVirtualDevice()); auto topology_desc = client->GetTopologyForDevices(ordered_devices); if (topology_desc.ok()) { // Some backends support a topology description which provides a better