From 075cead381c1e087caef4575bd746dc94f21c091 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 30 Nov 2023 12:55:52 -0800 Subject: [PATCH] Refactor ExecuteReplicated to operate on sharded data directly (#5737) * Refactor ExecuteReplicated to operate on sharded data directly * Remove old handlers * formatting * Improve naming and logging * update docstring * Remove obsolete unit tests * improve comment * Remove slow calls to get output shapes. * fix implicit sharding * remove declarations of input/output handlers * formatting * give everything a manual placeholder sharding * see if CI passes * formatting * Shard parameter and output handling * Use absl::BlockingCounter * formatting * fix merge * Assign valid output shardings * tune and document costs * formatting * implicitly replicate output to match outputhandler * clarify ReplicateShardedData * fix merge --- test/cpp/test_xla_sharding.cpp | 94 --------- torch_xla/csrc/runtime/BUILD | 1 + torch_xla/csrc/runtime/computation_client.h | 21 +- .../csrc/runtime/pjrt_computation_client.cc | 180 ++++++++++-------- .../csrc/runtime/pjrt_computation_client.h | 20 +- torch_xla/csrc/xla_graph_executor.cpp | 24 +-- torch_xla/csrc/xla_sharding_util.cpp | 80 -------- torch_xla/csrc/xla_sharding_util.h | 26 --- 8 files changed, 128 insertions(+), 318 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 1bd39e91783..08beb49b7b2 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -348,100 +348,6 @@ TEST_F(XLAShardingTest, CreateTensorsData) { } } -TEST_F(XLAShardingTest, InputHandler) { - if ((torch_xla::runtime::sys_util::GetEnvString( - torch_xla::runtime::env::kEnvPjRtDevice, "") == "") || - (torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() < - 2)) { - GTEST_SKIP() - << "`PJRT_DEVICE` is not set, with more than 2 local devices, (" - << torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() - << " local devices detected)."; - } - - std::vector tensors(2); - auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = - CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - std::fill_n(tensors.begin(), tensors.size(), tensor); - std::vector devices = {"TPU:0", "TPU:1"}; - std::vector shardings = { - nullptr, std::make_shared( - xla::HloSharding::Replicate().ToProto(), tensor_shape)}; - std::vector tensors_data = - CreateTensorsData(tensors, shardings, devices); - - std::vector arguments = - UnwrapXlaData(tensors_data); - auto arguments_by_device = ShardingUtil::InputHandler(arguments, devices); - - auto arg0_dev0 = arguments_by_device[0][0]; - auto arg0_dev1 = arguments_by_device[1][0]; - EXPECT_TRUE(XlaDataValuesEqual(arg0_dev0, arg0_dev1, at::kFloat)); - - auto arg1_dev0 = arguments_by_device[0][1]; - auto arg1_dev1 = arguments_by_device[1][1]; - EXPECT_TRUE(XlaDataValuesEqual(arg1_dev0, arg1_dev1, at::kFloat)); -} - -TEST_F(XLAShardingTest, OutputHandler) { - if ((torch_xla::runtime::sys_util::GetEnvString( - torch_xla::runtime::env::kEnvPjRtDevice, "") == "") || - (torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() < - 2)) { - GTEST_SKIP() - << "`PJRT_DEVICE` is not set, with more than 2 local devices, (" - << torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() - << " local devices detected)."; - } - - std::vector devices = - torch_xla::runtime::GetComputationClient()->GetLocalDevices(); - - // Prepare an input vecotr `outputs` with 2 arguments per device. 
- std::vector> - outputs; - outputs.reserve(devices.size()); - at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); - for (auto device : devices) { - outputs.push_back( - UnwrapXlaData(CreateTensorsData({tensor, tensor}, {device, device}))); - } - - xla::Shape tensor_shape = - CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - auto sharding_spec = std::make_shared( - xla::HloSharding::Tile1D( - CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()), - devices.size()) - .ToProto(), - tensor_shape); - std::vector sharding_specs{sharding_spec, - sharding_spec}; - - // Shard a PjRtData into a PjRtShardedData. - std::vector sharded_outputs = - ShardingUtil::OutputHandler(outputs, sharding_specs, - /*replicated_output=*/true); - EXPECT_EQ(sharded_outputs.size(), 2); - auto shards = torch_xla::runtime::GetComputationClient()->GetDataShards( - sharded_outputs[0]); - EXPECT_EQ(shards.size(), devices.size()); - EXPECT_FALSE( - xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape)); - - // Wrap sharded data into a PjRtShardedData with `devices.size()` shards. - std::vector wrapped_outputs = - ShardingUtil::OutputHandler(outputs, sharding_specs, - /*replicated_output=*/false); - EXPECT_EQ(wrapped_outputs.size(), 2); - shards = torch_xla::runtime::GetComputationClient()->GetDataShards( - wrapped_outputs[0]); - EXPECT_EQ(shards.size(), devices.size()); - EXPECT_TRUE( - xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape)); -} - TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4}); int64_t n_devices = diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index fa7e3578729..df2b104876f 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -101,6 +101,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", + "@tsl//tsl/platform:env", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 28e09be6c68..a1223c5ef7e 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -291,20 +291,13 @@ class ComputationClient { const ExecuteComputationOptions& options = ExecuteComputationOptions{}) = 0; - // Executes the computation in replicated mode. - // The size of the arguments vector is the number of replicas to execute, - // and it must match the size of the computation.devices() as well as the - // devices passed as argument. The destination devices for each replicated - // computation come from the devices the Data objects are stored into, which - // must match the devices argument. Within arguments[i], every Data - // object must be coming from the same device. Returns a vector (of the same - // size of the arguments vector) with the results of the parallel execution. - // The result[i], a vector itself, will be the result of the computation fed - // with arguments[i]. If options.explode_tuple is true, the output tuples will - // be decomposed into their single elements. - virtual std::vector> ExecuteReplicated( - const Computation& computation, - const std::vector>& arguments, + // Executes the computation on multiple local devices in parallel. 
+ // Each argument to the executable is expected to be sharded in the same order + // as `devices`. If options.explode_tuple is true, the output tuples will be + // decomposed into their single elements. Returns a vector of outputs, each + // of which is sharded in the same order as `devices`. + virtual std::vector ExecuteReplicated( + const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1aa017bc33d..0bd42b3ad6a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -1,6 +1,7 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include +#include #include #include @@ -386,13 +387,16 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::move(status_or.value())); } -ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData( +std::shared_ptr +PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { - if (PjRtShardedData* sharded_data = - dynamic_cast(handle.get())) { + if (auto unsharded_data = std::dynamic_pointer_cast(handle)) { + return unsharded_data; + } else if (auto sharded_data = + std::dynamic_pointer_cast(handle)) { XLA_COUNTER("ReplicateShardedData", 1); - TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() - << ", shape=" << handle->shape() << ")"; + TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle() + << ", shape=" << sharded_data->shape() << ")"; if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { // Data is replicated, return the first shard return sharded_data->shards[0]; @@ -428,35 +432,22 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData( std::shared_ptr> computations = Compile(std::move(instances)); - auto shards = sharded_data->shards; - XLA_CHECK_EQ(shards.size(), GetLocalDevices().size()); - auto device_index = build_index_map(GetLocalDevices()); - - std::vector> arguments_by_device( - GetLocalDevices().size(), std::vector(1)); - for (auto shard : shards) { - std::vector device_spec = - absl::StrSplit(shard->device(), ':'); - XLA_CHECK_EQ(device_spec.size(), 2) - << "Invalid device specification: " << shard->device(); - int device_i = device_index[std::stoi(device_spec[1])]; - TF_VLOG(3) << shard->device() << " is mapped to local device index " - << device_i; - arguments_by_device[device_i][0] = shard; - } torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; auto sharded_results = - ExecuteReplicated(*computations.front(), arguments_by_device, + ExecuteReplicated(*computations.front(), {sharded_data}, GetLocalDevices(), execute_options); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; - XLA_CHECK(sharded_results[0].size() == 1) + XLA_CHECK(sharded_results.size() == 1) << "Wrong number of outputs, expected: 1, actual: " - << sharded_results[0].size(); - return sharded_results[0][0]; + << sharded_results.size(); + return std::dynamic_pointer_cast(sharded_results[0]) + ->shards[0]; } - return handle; + + XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got " + << handle->ToString(); } std::vector PjRtComputationClient::TransferFromServer( @@ -472,12 +463,12 @@ std::vector PjRtComputationClient::TransferFromServer( for (auto handle : handles) { // Use XLA replication to reassemble the sharded 
data. If input handle // is not sharded, then it is a no-op. - auto new_handle = ReplicateShardedData(handle); - const PjRtData& pjrt_data = dynamic_cast(*new_handle); + std::shared_ptr pjrt_data = ReplicateShardedData(handle); + XLA_CHECK(pjrt_data); xla::Literal& literal = - literals.emplace_back(host_output_shape(pjrt_data.buffer.get())); - futures.push_back(pjrt_data.buffer->ToLiteral(&literal)); + literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); + futures.push_back(pjrt_data->buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); } @@ -642,10 +633,10 @@ PjRtComputationClient::ExecuteComputation( return datas; } -std::vector> +std::vector PjRtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, - const std::vector>& arguments, + absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) { // Shared ownership of the timed section ensures that it will only get logged @@ -657,34 +648,41 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); const PjRtComputation& pjrt_computation = dynamic_cast(computation); - XLA_CHECK(devices.size() == arguments.size()) - << "ExecuteReplicated over " << devices.size() << " devices, but " - << arguments.size() << " arguments devices."; - absl::BlockingCounter counter(devices.size()); - std::vector> argument_handles(devices.size()); + + std::vector> argument_handles( + devices.size(), std::vector(arguments.size())); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); - for (int32_t i = 0; i < devices.size(); ++i) { - auto buffer_converter = [&, i]() { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); - - std::vector buffers; - for (auto& argument : arguments[i]) { - const PjRtData* pjrt_data = dynamic_cast(argument.get()); - - XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) - << pjrt_device->DebugString() << " vs " - << pjrt_data->buffer->device()->DebugString(); - buffers.push_back(pjrt_data->buffer.get()); - } - argument_handles[i] = std::move(buffers); - counter.DecrementCount(); - }; - thread::Schedule(std::move(buffer_converter)); - } + + absl::BlockingCounter counter(arguments.size()); + + // Time in nanoseconds that it takes to prepare an argument. Used to tune + // number of threads spawned by ParallelFor. Measured on 2023/11/28. 
+ static constexpr int64_t argument_handle_cost_ns = 10000; + pool_.ParallelFor( + arguments.size(), argument_handle_cost_ns, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto pjrt_data = + std::dynamic_pointer_cast(arguments[i]); + XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) + << "Expected one shard per device"; + + for (int32_t d = 0; d < devices.size(); d++) { + std::shared_ptr shard = pjrt_data->shards[d]; + + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); + XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); + XLA_CHECK(pjrt_device->IsAddressable()) + << pjrt_device->DebugString(); + + argument_handles[d][i] = shard->buffer.get(); + } + counter.DecrementCount(); + }; + }); counter.Wait(); } @@ -726,38 +724,58 @@ PjRtComputationClient::ExecuteReplicated( })); } - std::vector> data_handles; - data_handles.reserve(results.size()); - std::vector dims(results.size()); + size_t num_outputs = results[0].size(); + std::vector data_handles(num_outputs); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_result_handle", tsl::profiler::TraceMeLevel::kInfo); - for (int32_t i = 0; i < results.size(); ++i) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - XLA_CHECK(pjrt_device->IsAddressable()) - << pjrt_device->DebugString() << " is not addressable."; - - std::vector datas; - datas.reserve(results[i].size()); - dims[i] = results[i].size(); - for (int32_t j = 0; j < results[i].size(); ++j) { - std::unique_ptr buffer = std::move(results[i][j]); - XLA_CHECK(pjrt_device == buffer->device()) - << "Exepcted device: " << pjrt_device->DebugString() - << " vs. actual device: " << buffer->device()->DebugString(); - - std::shared_ptr data = - std::make_shared(devices[i], std::move(buffer)); - datas.push_back(data); - } - data_handles.push_back(datas); - } + + const xla::Shape& result_shape = computation.program_shape().result(); + TF_VLOG(3) << "Processing output with shape " << result_shape.ToString(); + const std::vector& output_shapes = + result_shape.IsTuple() ? result_shape.tuple_shapes() + : std::vector({result_shape}); + XLA_CHECK_EQ(output_shapes.size(), num_outputs); + + const std::vector& output_shardings = + pjrt_computation.output_shardings_ + ? *pjrt_computation.output_shardings_ + : + // Without an explicit sharding annotation, the output is implicitly + // replicated + std::vector(num_outputs, xla::HloSharding::Replicate().ToProto()); + XLA_CHECK_EQ(output_shardings.size(), num_outputs); + + absl::BlockingCounter counter(num_outputs); + + // Time in nanoseconds that it takes to process a result buffer. + // Measured on 2023/11/28. 
+ static constexpr int64_t result_handle_cost_ns = 10000; + pool_.ParallelFor( + num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + std::vector> shards(devices.size()); + for (int32_t d = 0; d < devices.size(); d++) { + std::unique_ptr buffer = + std::move(results[d][i]); + shards[d] = + std::make_shared(devices[d], std::move(buffer)); + } + + data_handles[i] = std::make_shared( + spmd_device_str, output_shapes[i], std::move(shards), + output_shardings[i]); + TF_VLOG(5) << "Created sharded data with shape " + << data_handles[i]->shape().ToString(); + counter.DecrementCount(); + } + }); + counter.Wait(); } - TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " - << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; + TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; return data_handles; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index e5eaf4039ef..93760ee6b05 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -12,6 +12,8 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" @@ -40,9 +42,6 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferToServer( absl::Span> tensors) override; - // Use XLA replication to re-assemble the sharded data. - DataPtr ReplicateShardedData(const DataPtr& handle); - std::vector TransferFromServer( absl::Span handles) override; @@ -60,9 +59,8 @@ class PjRtComputationClient : public ComputationClient { const std::string& device, const ExecuteComputationOptions& options) override; - std::vector> ExecuteReplicated( - const Computation& computation, - const std::vector>& arguments, + std::vector ExecuteReplicated( + const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; @@ -114,6 +112,8 @@ class PjRtComputationClient : public ComputationClient { std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; OperationManager operation_manager_; + tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( + tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency()); xla::PjRtDevice* StringToPjRtDevice(const std::string& device); @@ -231,10 +231,16 @@ class PjRtComputationClient : public ComputationClient { std::vector devices, std::unique_ptr executable) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) {} + executable(std::move(executable)) { + output_shardings_ = this->executable->GetOutputShardings(); + } std::unique_ptr executable; + std::optional> output_shardings_; }; + + // Use XLA replication to re-assemble the sharded data. 
+ std::shared_ptr ReplicateShardedData(const DataPtr& handle); }; } // namespace runtime diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index b02ce64fd9c..2b193bdbb96 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -705,19 +705,15 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( if (async->cached_computation->is_sharded) { std::vector devices = runtime::GetComputationClient()->GetLocalDevices(); - std::vector> - device_arguments = ShardingUtil::InputHandler( - UnwrapXlaData(async->parameters_data), devices); runtime::ComputationClient::ExecuteReplicatedOptions execute_options; // OutputHandler creates sharded data for sharded // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - ShardingUtil::OutputHandler( - runtime::GetComputationClient()->ExecuteReplicated( - *async->cached_computation->computation, device_arguments, - devices, execute_options), - sharding_specs); + runtime::GetComputationClient()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -973,9 +969,6 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( if (async->cached_computation->is_sharded) { std::vector devices = runtime::GetComputationClient()->GetLocalDevices(); - std::vector> - device_arguments = ShardingUtil::InputHandler( - UnwrapXlaData(async->parameters_data), devices); runtime::ComputationClient::ExecuteReplicatedOptions execute_options; TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) @@ -984,11 +977,10 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - ShardingUtil::OutputHandler( - runtime::GetComputationClient()->ExecuteReplicated( - *async->cached_computation->computation, device_arguments, - devices, execute_options), - sharding_specs, /*replicated_output=*/false); + runtime::GetComputationClient()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 8a6caa97992..6fbe2f80574 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -340,86 +340,6 @@ xla::HloModuleProto ShardingUtil::SpmdPartitioningPass( return module.get()->ToProto(); } -std::vector> -ShardingUtil::InputHandler( - std::vector arguments, - std::vector devices) { - tsl::profiler::TraceMe activity("InputHandler", - tsl::profiler::TraceMeLevel::kInfo); - std::vector> - arguments_by_device( - devices.size(), - std::vector(arguments.size())); - // This assumes that the (local) devices are sorted, in order to associate - // the first local index with the first global device ordinal. 
- auto device_index = build_index_map(devices); - - absl::BlockingCounter counter(devices.size()); - - for (int i = 0; i < devices.size(); i++) { - auto argument_setter = [&, i]() { - for (int64_t argument_i = 0; argument_i < arguments.size(); - ++argument_i) { - runtime::ComputationClient::DataPtr shard = - runtime::GetComputationClient()->GetDataShard(arguments[argument_i], - i); - int global_ordinal = ParseDeviceString(shard->device()).ordinal(); - int device_i = device_index[global_ordinal]; - arguments_by_device[device_i][argument_i] = shard; - } - counter.DecrementCount(); - }; - thread::Schedule(std::move(argument_setter)); - } - counter.Wait(); - return arguments_by_device; -} - -std::vector ShardingUtil::OutputHandler( - std::vector> - sharded_results, - std::vector sharding_specs, - bool replicated_output) { - tsl::profiler::TraceMe activity("OutputHandler", - tsl::profiler::TraceMeLevel::kInfo); - std::vector outputs; - outputs.reserve(sharding_specs.size()); - for (int i = 0; i < sharding_specs.size(); ++i) { - XLATensor::ShardingSpecPtr sharding = sharding_specs[i]; - if (replicated_output && sharding && - (sharding->sharding.type() != xla::OpSharding::REPLICATED)) { - // Reshards replicated output if `sharding` is present. - std::vector tensors = XlaDataToTensors( - {sharded_results[0][i]}, - MaybeUpcastToHostTorchType(sharding->shape.element_type())); - outputs.push_back( - std::dynamic_pointer_cast( - CreateTensorsData( - tensors, {sharding}, - std::vector{GetVirtualDevice().toString()})[0])); - } else { - // The output is sharded or replicated. - std::vector shards; - shards.reserve(sharded_results.size()); - for (int j = 0; j < sharded_results.size(); ++j) { - XLA_CHECK(sharded_results[j][i]->HasValue()); - shards.push_back(sharded_results[j][i]); - } - if (!sharding) { - // Without an explicit sharding annotation, the output is implicitly - // replicated - sharding = std::make_shared( - xla::HloSharding::Replicate().ToProto(), - sharded_results[0][i]->shape()); - } - outputs.push_back(runtime::GetComputationClient()->WrapDataShards( - shards, GetVirtualDevice().toString(), sharding->shape, - sharding->sharding)); - } - } - return outputs; -} - std::vector ShardingUtil::GetShardShape( const XLATensor::ShardingSpecPtr shardings) { auto sharding = shardings->sharding; diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index f6846664790..848b1e62db1 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -58,32 +58,6 @@ class ShardingUtil { bool unroll_windowed_einsum = false, bool bidirectional_windowed_einsum = false); - // Reshuffles arguments (sharded or replicated) on the devices. The - // size of the arguments vector must match that of the sharding_specs. - // The the returned arguments will be in 1:1 correspondence with the `devices` - // vector, so the `i`th result will belong on the `i`th device. - // TODO(yeounoh) avoiding pre-loading of the unpartitioned input arguments - // might improve the performance and save the bandwidth. - static std::vector> - InputHandler(std::vector arguments, - std::vector devices); - - // Processes replicated execution results, where `sharded_results` contains - // `PjRtData` handles and spans the number of devices (outer) and the number - // of arguments (innner). This requires `sharding_specs` of the same size as - // the number of arguments. `sharding_specs` can contain `nullptr` if the - // corresponding result argument is not sharded. 
The replicated execution - // `replicated_output=true` leaves the results in replicated states, which is - // aligned with the default exepctation of XLA compiler. However, we override - // the compiler's default behavior and allow the execution to return sharded - // results and wrap sharded arguments into `PjRtShardedData`. This returns a - // vector of size that is equal to the number of arguments. - static std::vector OutputHandler( - std::vector> - sharded_results, - std::vector sharding_specs, - bool replicated_output = false); - // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all // have the same shape.
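
Usage sketch (illustrative, not part of the diff): the snippet below shows how a caller invokes the refactored ExecuteReplicated, mirroring the updated call sites in xla_graph_executor.cpp above. It assumes compilation inside torch_xla against the headers touched by this patch; the helper name RunShardedComputation, the standalone framing, and passing the client in explicitly are assumptions for illustration only.

    #include <string>
    #include <vector>

    #include "torch_xla/csrc/runtime/computation_client.h"

    namespace torch_xla {

    // Hypothetical helper: runs an SPMD computation where `parameters_data`
    // holds one sharded handle per parameter (e.g. PjRtShardedData), instead
    // of the old one-argument-vector-per-device layout.
    std::vector<runtime::ComputationClient::DataPtr> RunShardedComputation(
        runtime::ComputationClient* client,
        const runtime::ComputationClient::Computation& computation,
        const std::vector<runtime::ComputationClient::DataPtr>&
            parameters_data) {
      std::vector<std::string> devices = client->GetLocalDevices();
      runtime::ComputationClient::ExecuteReplicatedOptions execute_options;

      // No InputHandler pass: the sharded arguments are handed over as-is and
      // the client builds the per-device buffer lists internally.
      std::vector<runtime::ComputationClient::DataPtr> outputs =
          client->ExecuteReplicated(computation, parameters_data, devices,
                                    execute_options);

      // No OutputHandler pass either: each element of `outputs` is already a
      // sharded handle spanning `devices`.
      return outputs;
    }

    }  // namespace torch_xla

This is the same shape as the simplified ScheduleSyncTensorsGraph and ExecuteComputationWithBarrier paths in the diff, where the InputHandler/OutputHandler calls are dropped and the ExecuteReplicated result is wrapped directly with WrapXlaData.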
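
The argument and result handling in PjRtComputationClient::ExecuteReplicated relies on tsl::thread::ThreadPool::ParallelFor driven by a per-item cost hint (the tuned argument_handle_cost_ns / result_handle_cost_ns constants) plus an absl::BlockingCounter. Below is a minimal, self-contained sketch of that pattern under the same headers the patch adds; the per-element work (squaring an integer), the pool name, and the element count are placeholders, not part of the change.

    #include <cstdint>
    #include <thread>
    #include <vector>

    #include "absl/synchronization/blocking_counter.h"
    #include "tsl/platform/env.h"
    #include "tsl/platform/threadpool.h"

    int main() {
      tsl::thread::ThreadPool pool(tsl::Env::Default(), "example",
                                   std::thread::hardware_concurrency());
      std::vector<int64_t> results(1000);

      // Rough per-element cost in nanoseconds; like the constants in the
      // patch, it only guides how ParallelFor splits the index range across
      // worker threads.
      static constexpr int64_t cost_ns = 10000;

      absl::BlockingCounter counter(results.size());
      pool.ParallelFor(results.size(), cost_ns,
                       [&](int64_t start, int64_t end) {
                         for (int64_t i = start; i < end; ++i) {
                           // Stand-in for per-argument or per-result buffer
                           // preparation.
                           results[i] = i * i;
                           counter.DecrementCount();
                         }
                       });
      // As in the patch: block until every element has been processed before
      // using `results`.
      counter.Wait();
      return 0;
    }

The cost constants do not change correctness; they only trade off thread fan-out against scheduling overhead, which is why the patch documents when they were measured.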