From 7eaa91ba58435f52c13958ff09aa220aecba8873 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 25 Oct 2023 22:49:27 +0000 Subject: [PATCH 01/24] Refactor ExecuteReplicated to operate on sharded data directly --- torch_xla/csrc/runtime/computation_client.h | 4 +- .../csrc/runtime/pjrt_computation_client.cc | 115 +++++++++--------- .../csrc/runtime/pjrt_computation_client.h | 10 +- torch_xla/csrc/xla_graph_executor.cpp | 18 +-- 4 files changed, 67 insertions(+), 80 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 28e09be6c68..3ef9629ed9e 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -302,9 +302,9 @@ class ComputationClient { // The result[i], a vector itself, will be the result of the computation fed // with arguments[i]. If options.explode_tuple is true, the output tuples will // be decomposed into their single elements. - virtual std::vector> ExecuteReplicated( + virtual std::vector ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1aa017bc33d..3d2850b81d9 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -386,7 +386,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::move(status_or.value())); } -ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData( +std::shared_ptr PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { if (PjRtShardedData* sharded_data = dynamic_cast(handle.get())) { @@ -428,35 +428,21 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData( std::shared_ptr> computations = Compile(std::move(instances)); - auto shards = sharded_data->shards; - XLA_CHECK_EQ(shards.size(), GetLocalDevices().size()); - auto device_index = build_index_map(GetLocalDevices()); - - std::vector> arguments_by_device( - GetLocalDevices().size(), std::vector(1)); - for (auto shard : shards) { - std::vector device_spec = - absl::StrSplit(shard->device(), ':'); - XLA_CHECK_EQ(device_spec.size(), 2) - << "Invalid device specification: " << shard->device(); - int device_i = device_index[std::stoi(device_spec[1])]; - TF_VLOG(3) << shard->device() << " is mapped to local device index " - << device_i; - arguments_by_device[device_i][0] = shard; - } torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; auto sharded_results = - ExecuteReplicated(*computations.front(), arguments_by_device, + ExecuteReplicated(*computations.front(), {handle}, GetLocalDevices(), execute_options); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; - XLA_CHECK(sharded_results[0].size() == 1) + XLA_CHECK(sharded_results.size() == 1) << "Wrong number of outputs, expected: 1, actual: " - << sharded_results[0].size(); - return sharded_results[0][0]; + << sharded_results.size(); + return std::dynamic_pointer_cast(sharded_results[0])->shards[0]; } - return handle; + auto pjrt_data = std::dynamic_pointer_cast(handle); + XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData."; + return pjrt_data; } std::vector PjRtComputationClient::TransferFromServer( @@ -472,8 +458,8 @@ std::vector PjRtComputationClient::TransferFromServer( for (auto handle : handles) { // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. - auto new_handle = ReplicateShardedData(handle); - const PjRtData& pjrt_data = dynamic_cast(*new_handle); + std::shared_ptr pjrt_data = ReplicateShardedData(handle); + XLA_CHECK(pjrt_data); xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data.buffer.get())); @@ -642,10 +628,10 @@ PjRtComputationClient::ExecuteComputation( return datas; } -std::vector> +std::vector PjRtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, - const std::vector>& arguments, + absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) { // Shared ownership of the timed section ensures that it will only get logged @@ -657,30 +643,29 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); const PjRtComputation& pjrt_computation = dynamic_cast(computation); - XLA_CHECK(devices.size() == arguments.size()) - << "ExecuteReplicated over " << devices.size() << " devices, but " - << arguments.size() << " arguments devices."; - absl::BlockingCounter counter(devices.size()); - std::vector> argument_handles(devices.size()); + // XLA_CHECK(devices.size() == arguments.size()) + // << "ExecuteReplicated over " << devices.size() << " devices, but " + // << arguments.size() << " arguments devices."; + absl::BlockingCounter counter(arguments.size()); + std::vector> argument_handles(devices.size(), std::vector(arguments.size())); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); - for (int32_t i = 0; i < devices.size(); ++i) { + for (int32_t i = 0; i < arguments.size(); ++i) { auto buffer_converter = [&, i]() { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); + XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) << "Expected one shard per device"; - std::vector buffers; - for (auto& argument : arguments[i]) { - const PjRtData* pjrt_data = dynamic_cast(argument.get()); + for (int32_t d = 0; d < devices.size(); d++) { + std::shared_ptr shard = pjrt_data->shards[d]; - XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) - << pjrt_device->DebugString() << " vs " - << pjrt_data->buffer->device()->DebugString(); - buffers.push_back(pjrt_data->buffer.get()); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); + XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); + XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + + argument_handles[d][i] = shard->buffer.get(); } - argument_handles[i] = std::move(buffers); counter.DecrementCount(); }; thread::Schedule(std::move(buffer_converter)); @@ -726,7 +711,7 @@ PjRtComputationClient::ExecuteReplicated( })); } - std::vector> data_handles; + std::vector data_handles; data_handles.reserve(results.size()); std::vector dims(results.size()); @@ -734,30 +719,42 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_result_handle", tsl::profiler::TraceMeLevel::kInfo); - for (int32_t i = 0; i < results.size(); ++i) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); - XLA_CHECK(pjrt_device->IsAddressable()) - << pjrt_device->DebugString() << " is not addressable."; - - std::vector datas; - datas.reserve(results[i].size()); - dims[i] = results[i].size(); - for (int32_t j = 0; j < results[i].size(); ++j) { + size_t num_returns = results[0].size(); + + std::vector> hlo_modules = pjrt_computation.executable->GetHloModules().value(); + XLA_CHECK_EQ(hlo_modules.size(), 1) << "Expected one HLO module with multiple outputs."; + const xla::Shape& result_shape = hlo_modules[0]->result_shape(); + const std::vector& output_shapes = result_shape.IsTuple() ? hlo_modules[0]->result_shape().tuple_shapes() : std::vector({result_shape}); + XLA_CHECK_EQ(output_shapes.size(), num_returns) << "Output shape: " << result_shape.ToString(); + + // TODO(wcromar): Implement this in PJRT C API client + // std::vector output_shapes = pjrt_computation.executable->GetOutputShapes().value(); + std::vector output_shardings = pjrt_computation.executable->GetOutputShardings().value(); + + XLA_CHECK_EQ(output_shardings.size(), num_returns); + + for (int32_t j = 0; j < num_returns; j++) { + std::vector> shards(devices.size()); + + for (int32_t i = 0; i < devices.size(); i++) { + XLA_CHECK_EQ(results[i].size(), num_returns); + std::unique_ptr buffer = std::move(results[i][j]); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); XLA_CHECK(pjrt_device == buffer->device()) - << "Exepcted device: " << pjrt_device->DebugString() + << "Exepected device: " << pjrt_device->DebugString() << " vs. actual device: " << buffer->device()->DebugString(); - std::shared_ptr data = - std::make_shared(devices[i], std::move(buffer)); - datas.push_back(data); + shards[i] = std::make_shared(devices[i], std::move(buffer)); } - data_handles.push_back(datas); + + auto data = + std::make_shared(spmd_device_str, output_shapes[j], std::move(shards), output_shardings[j]); + data_handles.push_back(data); } } - TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " - << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; + TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; return data_handles; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index e5eaf4039ef..cd7137029fd 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -40,9 +40,6 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferToServer( absl::Span> tensors) override; - // Use XLA replication to re-assemble the sharded data. - DataPtr ReplicateShardedData(const DataPtr& handle); - std::vector TransferFromServer( absl::Span handles) override; @@ -60,9 +57,9 @@ class PjRtComputationClient : public ComputationClient { const std::string& device, const ExecuteComputationOptions& options) override; - std::vector> ExecuteReplicated( + std::vector ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; @@ -235,6 +232,9 @@ class PjRtComputationClient : public ComputationClient { std::unique_ptr executable; }; + + // Use XLA replication to re-assemble the sharded data. + std::shared_ptr ReplicateShardedData(const DataPtr& handle); }; } // namespace runtime diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index b02ce64fd9c..6ab15fc96ab 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -705,19 +705,14 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( if (async->cached_computation->is_sharded) { std::vector devices = runtime::GetComputationClient()->GetLocalDevices(); - std::vector> - device_arguments = ShardingUtil::InputHandler( - UnwrapXlaData(async->parameters_data), devices); runtime::ComputationClient::ExecuteReplicatedOptions execute_options; // OutputHandler creates sharded data for sharded // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - ShardingUtil::OutputHandler( runtime::GetComputationClient()->ExecuteReplicated( - *async->cached_computation->computation, device_arguments, - devices, execute_options), - sharding_specs); + *async->cached_computation->computation, UnwrapXlaData(async->parameters_data), + devices, execute_options); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -973,9 +968,6 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( if (async->cached_computation->is_sharded) { std::vector devices = runtime::GetComputationClient()->GetLocalDevices(); - std::vector> - device_arguments = ShardingUtil::InputHandler( - UnwrapXlaData(async->parameters_data), devices); runtime::ComputationClient::ExecuteReplicatedOptions execute_options; TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) @@ -984,11 +976,9 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - ShardingUtil::OutputHandler( runtime::GetComputationClient()->ExecuteReplicated( - *async->cached_computation->computation, device_arguments, - devices, execute_options), - sharding_specs, /*replicated_output=*/false); + *async->cached_computation->computation, UnwrapXlaData(async->parameters_data), + devices, execute_options); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) From 0cc6dcd99a087ab58dfe7e73daa556886f905d13 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 25 Oct 2023 22:51:59 +0000 Subject: [PATCH 02/24] Remove old handlers --- torch_xla/csrc/xla_sharding_util.cpp | 80 ---------------------------- 1 file changed, 80 deletions(-) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 8a6caa97992..6fbe2f80574 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -340,86 +340,6 @@ xla::HloModuleProto ShardingUtil::SpmdPartitioningPass( return module.get()->ToProto(); } -std::vector> -ShardingUtil::InputHandler( - std::vector arguments, - std::vector devices) { - tsl::profiler::TraceMe activity("InputHandler", - tsl::profiler::TraceMeLevel::kInfo); - std::vector> - arguments_by_device( - devices.size(), - std::vector(arguments.size())); - // This assumes that the (local) devices are sorted, in order to associate - // the first local index with the first global device ordinal. - auto device_index = build_index_map(devices); - - absl::BlockingCounter counter(devices.size()); - - for (int i = 0; i < devices.size(); i++) { - auto argument_setter = [&, i]() { - for (int64_t argument_i = 0; argument_i < arguments.size(); - ++argument_i) { - runtime::ComputationClient::DataPtr shard = - runtime::GetComputationClient()->GetDataShard(arguments[argument_i], - i); - int global_ordinal = ParseDeviceString(shard->device()).ordinal(); - int device_i = device_index[global_ordinal]; - arguments_by_device[device_i][argument_i] = shard; - } - counter.DecrementCount(); - }; - thread::Schedule(std::move(argument_setter)); - } - counter.Wait(); - return arguments_by_device; -} - -std::vector ShardingUtil::OutputHandler( - std::vector> - sharded_results, - std::vector sharding_specs, - bool replicated_output) { - tsl::profiler::TraceMe activity("OutputHandler", - tsl::profiler::TraceMeLevel::kInfo); - std::vector outputs; - outputs.reserve(sharding_specs.size()); - for (int i = 0; i < sharding_specs.size(); ++i) { - XLATensor::ShardingSpecPtr sharding = sharding_specs[i]; - if (replicated_output && sharding && - (sharding->sharding.type() != xla::OpSharding::REPLICATED)) { - // Reshards replicated output if `sharding` is present. - std::vector tensors = XlaDataToTensors( - {sharded_results[0][i]}, - MaybeUpcastToHostTorchType(sharding->shape.element_type())); - outputs.push_back( - std::dynamic_pointer_cast( - CreateTensorsData( - tensors, {sharding}, - std::vector{GetVirtualDevice().toString()})[0])); - } else { - // The output is sharded or replicated. - std::vector shards; - shards.reserve(sharded_results.size()); - for (int j = 0; j < sharded_results.size(); ++j) { - XLA_CHECK(sharded_results[j][i]->HasValue()); - shards.push_back(sharded_results[j][i]); - } - if (!sharding) { - // Without an explicit sharding annotation, the output is implicitly - // replicated - sharding = std::make_shared( - xla::HloSharding::Replicate().ToProto(), - sharded_results[0][i]->shape()); - } - outputs.push_back(runtime::GetComputationClient()->WrapDataShards( - shards, GetVirtualDevice().toString(), sharding->shape, - sharding->sharding)); - } - } - return outputs; -} - std::vector ShardingUtil::GetShardShape( const XLATensor::ShardingSpecPtr shardings) { auto sharding = shardings->sharding; From e9be48fbd4f8c850812e5030ba1ceb564841679e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 25 Oct 2023 22:53:42 +0000 Subject: [PATCH 03/24] formatting --- torch_xla/csrc/runtime/computation_client.h | 3 +- .../csrc/runtime/pjrt_computation_client.cc | 48 +++++++++++-------- .../csrc/runtime/pjrt_computation_client.h | 3 +- torch_xla/csrc/xla_graph_executor.cpp | 14 +++--- 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 3ef9629ed9e..cfd06aef94e 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -303,8 +303,7 @@ class ComputationClient { // with arguments[i]. If options.explode_tuple is true, the output tuples will // be decomposed into their single elements. virtual std::vector ExecuteReplicated( - const Computation& computation, - absl::Span arguments, + const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 3d2850b81d9..456b83ceafe 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -386,7 +386,8 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::move(status_or.value())); } -std::shared_ptr PjRtComputationClient::ReplicateShardedData( +std::shared_ptr +PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { if (PjRtShardedData* sharded_data = dynamic_cast(handle.get())) { @@ -430,15 +431,15 @@ std::shared_ptr PjRtComputationClient::Replicat torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = - ExecuteReplicated(*computations.front(), {handle}, - GetLocalDevices(), execute_options); + auto sharded_results = ExecuteReplicated( + *computations.front(), {handle}, GetLocalDevices(), execute_options); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; XLA_CHECK(sharded_results.size() == 1) << "Wrong number of outputs, expected: 1, actual: " << sharded_results.size(); - return std::dynamic_pointer_cast(sharded_results[0])->shards[0]; + return std::dynamic_pointer_cast(sharded_results[0]) + ->shards[0]; } auto pjrt_data = std::dynamic_pointer_cast(handle); XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData."; @@ -643,19 +644,20 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); const PjRtComputation& pjrt_computation = dynamic_cast(computation); - // XLA_CHECK(devices.size() == arguments.size()) - // << "ExecuteReplicated over " << devices.size() << " devices, but " - // << arguments.size() << " arguments devices."; + absl::BlockingCounter counter(arguments.size()); - std::vector> argument_handles(devices.size(), std::vector(arguments.size())); + std::vector> argument_handles( + devices.size(), std::vector(arguments.size())); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); for (int32_t i = 0; i < arguments.size(); ++i) { auto buffer_converter = [&, i]() { - auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); - XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) << "Expected one shard per device"; + auto pjrt_data = + std::dynamic_pointer_cast(arguments[i]); + XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) + << "Expected one shard per device"; for (int32_t d = 0; d < devices.size(); d++) { std::shared_ptr shard = pjrt_data->shards[d]; @@ -721,15 +723,22 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); size_t num_returns = results[0].size(); - std::vector> hlo_modules = pjrt_computation.executable->GetHloModules().value(); - XLA_CHECK_EQ(hlo_modules.size(), 1) << "Expected one HLO module with multiple outputs."; + std::vector> hlo_modules = + pjrt_computation.executable->GetHloModules().value(); + XLA_CHECK_EQ(hlo_modules.size(), 1) + << "Expected one HLO module with multiple outputs."; const xla::Shape& result_shape = hlo_modules[0]->result_shape(); - const std::vector& output_shapes = result_shape.IsTuple() ? hlo_modules[0]->result_shape().tuple_shapes() : std::vector({result_shape}); - XLA_CHECK_EQ(output_shapes.size(), num_returns) << "Output shape: " << result_shape.ToString(); + const std::vector& output_shapes = + result_shape.IsTuple() ? hlo_modules[0]->result_shape().tuple_shapes() + : std::vector({result_shape}); + XLA_CHECK_EQ(output_shapes.size(), num_returns) + << "Output shape: " << result_shape.ToString(); // TODO(wcromar): Implement this in PJRT C API client - // std::vector output_shapes = pjrt_computation.executable->GetOutputShapes().value(); - std::vector output_shardings = pjrt_computation.executable->GetOutputShardings().value(); + // std::vector output_shapes = + // pjrt_computation.executable->GetOutputShapes().value(); + std::vector output_shardings = + pjrt_computation.executable->GetOutputShardings().value(); XLA_CHECK_EQ(output_shardings.size(), num_returns); @@ -748,8 +757,9 @@ PjRtComputationClient::ExecuteReplicated( shards[i] = std::make_shared(devices[i], std::move(buffer)); } - auto data = - std::make_shared(spmd_device_str, output_shapes[j], std::move(shards), output_shardings[j]); + auto data = std::make_shared( + spmd_device_str, output_shapes[j], std::move(shards), + output_shardings[j]); data_handles.push_back(data); } } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index cd7137029fd..77866953ed0 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -58,8 +58,7 @@ class PjRtComputationClient : public ComputationClient { const ExecuteComputationOptions& options) override; std::vector ExecuteReplicated( - const Computation& computation, - absl::Span arguments, + const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 6ab15fc96ab..2b193bdbb96 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -710,9 +710,10 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - runtime::GetComputationClient()->ExecuteReplicated( - *async->cached_computation->computation, UnwrapXlaData(async->parameters_data), - devices, execute_options); + runtime::GetComputationClient()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -976,9 +977,10 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - runtime::GetComputationClient()->ExecuteReplicated( - *async->cached_computation->computation, UnwrapXlaData(async->parameters_data), - devices, execute_options); + runtime::GetComputationClient()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) From d3e3976b0248f46c4778372011ce33e44de235fa Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 26 Oct 2023 18:53:27 +0000 Subject: [PATCH 04/24] Improve naming and logging --- .../csrc/runtime/pjrt_computation_client.cc | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 456b83ceafe..9b79c025de5 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -721,45 +721,49 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_result_handle", tsl::profiler::TraceMeLevel::kInfo); - size_t num_returns = results[0].size(); + size_t num_outputs = results[0].size(); + // Next few calls are expected to output vectors of size [1 x outputs] std::vector> hlo_modules = pjrt_computation.executable->GetHloModules().value(); XLA_CHECK_EQ(hlo_modules.size(), 1) << "Expected one HLO module with multiple outputs."; const xla::Shape& result_shape = hlo_modules[0]->result_shape(); - const std::vector& output_shapes = - result_shape.IsTuple() ? hlo_modules[0]->result_shape().tuple_shapes() - : std::vector({result_shape}); - XLA_CHECK_EQ(output_shapes.size(), num_returns) - << "Output shape: " << result_shape.ToString(); - - // TODO(wcromar): Implement this in PJRT C API client - // std::vector output_shapes = - // pjrt_computation.executable->GetOutputShapes().value(); + TF_VLOG(3) << "Processing output with shape " << result_shape.ToString(); + + std::vector output_dims = + pjrt_computation.executable->GetOutputDimensions().value()[0]; + XLA_CHECK_EQ(output_dims.size(), num_outputs); + std::vector output_types = + pjrt_computation.executable->GetOutputElementTypes().value()[0]; + XLA_CHECK_EQ(output_types.size(), num_outputs); + std::vector output_shardings = pjrt_computation.executable->GetOutputShardings().value(); - XLA_CHECK_EQ(output_shardings.size(), num_returns); + XLA_CHECK_EQ(output_shardings.size(), num_outputs); - for (int32_t j = 0; j < num_returns; j++) { + for (int32_t i = 0; i < num_outputs; i++) { std::vector> shards(devices.size()); - for (int32_t i = 0; i < devices.size(); i++) { - XLA_CHECK_EQ(results[i].size(), num_returns); + for (int32_t d = 0; d < devices.size(); d++) { + XLA_CHECK_EQ(results[d].size(), num_outputs); - std::unique_ptr buffer = std::move(results[i][j]); - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]); + std::unique_ptr buffer = std::move(results[d][i]); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); XLA_CHECK(pjrt_device == buffer->device()) << "Exepected device: " << pjrt_device->DebugString() << " vs. actual device: " << buffer->device()->DebugString(); - shards[i] = std::make_shared(devices[i], std::move(buffer)); + shards[d] = std::make_shared(devices[d], std::move(buffer)); } auto data = std::make_shared( - spmd_device_str, output_shapes[j], std::move(shards), - output_shardings[j]); + spmd_device_str, + xla::ShapeUtil::MakeShape(output_types[i], output_dims[i]), + std::move(shards), output_shardings[i]); + TF_VLOG(5) << "Created sharded data with shape " + << data->shape().ToString(); data_handles.push_back(data); } } From 4dbac971fbfb5b1b53933245c803a5b0fad8d5e0 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 26 Oct 2023 19:05:13 +0000 Subject: [PATCH 05/24] update docstring --- torch_xla/csrc/runtime/computation_client.h | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index cfd06aef94e..a1223c5ef7e 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -291,17 +291,11 @@ class ComputationClient { const ExecuteComputationOptions& options = ExecuteComputationOptions{}) = 0; - // Executes the computation in replicated mode. - // The size of the arguments vector is the number of replicas to execute, - // and it must match the size of the computation.devices() as well as the - // devices passed as argument. The destination devices for each replicated - // computation come from the devices the Data objects are stored into, which - // must match the devices argument. Within arguments[i], every Data - // object must be coming from the same device. Returns a vector (of the same - // size of the arguments vector) with the results of the parallel execution. - // The result[i], a vector itself, will be the result of the computation fed - // with arguments[i]. If options.explode_tuple is true, the output tuples will - // be decomposed into their single elements. + // Executes the computation on multiple local devices in parallel. + // Each argument to the executable is expected to be sharded in the same order + // as `devices`. If options.explode_tuple is true, the output tuples will be + // decomposed into their single elements. Returns a vector of outputs, each + // of which is sharded in the same order as `devices`. virtual std::vector ExecuteReplicated( const Computation& computation, absl::Span arguments, absl::Span devices, From 5bc1a23ebc487b4f69f012bf043867794db537d2 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 26 Oct 2023 19:13:17 +0000 Subject: [PATCH 06/24] Remove obsolete unit tests --- test/cpp/test_xla_sharding.cpp | 94 ---------------------------------- 1 file changed, 94 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 1bd39e91783..08beb49b7b2 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -348,100 +348,6 @@ TEST_F(XLAShardingTest, CreateTensorsData) { } } -TEST_F(XLAShardingTest, InputHandler) { - if ((torch_xla::runtime::sys_util::GetEnvString( - torch_xla::runtime::env::kEnvPjRtDevice, "") == "") || - (torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() < - 2)) { - GTEST_SKIP() - << "`PJRT_DEVICE` is not set, with more than 2 local devices, (" - << torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() - << " local devices detected)."; - } - - std::vector tensors(2); - auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = - CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - std::fill_n(tensors.begin(), tensors.size(), tensor); - std::vector devices = {"TPU:0", "TPU:1"}; - std::vector shardings = { - nullptr, std::make_shared( - xla::HloSharding::Replicate().ToProto(), tensor_shape)}; - std::vector tensors_data = - CreateTensorsData(tensors, shardings, devices); - - std::vector arguments = - UnwrapXlaData(tensors_data); - auto arguments_by_device = ShardingUtil::InputHandler(arguments, devices); - - auto arg0_dev0 = arguments_by_device[0][0]; - auto arg0_dev1 = arguments_by_device[1][0]; - EXPECT_TRUE(XlaDataValuesEqual(arg0_dev0, arg0_dev1, at::kFloat)); - - auto arg1_dev0 = arguments_by_device[0][1]; - auto arg1_dev1 = arguments_by_device[1][1]; - EXPECT_TRUE(XlaDataValuesEqual(arg1_dev0, arg1_dev1, at::kFloat)); -} - -TEST_F(XLAShardingTest, OutputHandler) { - if ((torch_xla::runtime::sys_util::GetEnvString( - torch_xla::runtime::env::kEnvPjRtDevice, "") == "") || - (torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() < - 2)) { - GTEST_SKIP() - << "`PJRT_DEVICE` is not set, with more than 2 local devices, (" - << torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() - << " local devices detected)."; - } - - std::vector devices = - torch_xla::runtime::GetComputationClient()->GetLocalDevices(); - - // Prepare an input vecotr `outputs` with 2 arguments per device. - std::vector> - outputs; - outputs.reserve(devices.size()); - at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); - for (auto device : devices) { - outputs.push_back( - UnwrapXlaData(CreateTensorsData({tensor, tensor}, {device, device}))); - } - - xla::Shape tensor_shape = - CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - auto sharding_spec = std::make_shared( - xla::HloSharding::Tile1D( - CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()), - devices.size()) - .ToProto(), - tensor_shape); - std::vector sharding_specs{sharding_spec, - sharding_spec}; - - // Shard a PjRtData into a PjRtShardedData. - std::vector sharded_outputs = - ShardingUtil::OutputHandler(outputs, sharding_specs, - /*replicated_output=*/true); - EXPECT_EQ(sharded_outputs.size(), 2); - auto shards = torch_xla::runtime::GetComputationClient()->GetDataShards( - sharded_outputs[0]); - EXPECT_EQ(shards.size(), devices.size()); - EXPECT_FALSE( - xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape)); - - // Wrap sharded data into a PjRtShardedData with `devices.size()` shards. - std::vector wrapped_outputs = - ShardingUtil::OutputHandler(outputs, sharding_specs, - /*replicated_output=*/false); - EXPECT_EQ(wrapped_outputs.size(), 2); - shards = torch_xla::runtime::GetComputationClient()->GetDataShards( - wrapped_outputs[0]); - EXPECT_EQ(shards.size(), devices.size()); - EXPECT_TRUE( - xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape)); -} - TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4}); int64_t n_devices = From be7b3dad574beded61d6429d28be05056e3a7e13 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 26 Oct 2023 19:14:07 +0000 Subject: [PATCH 07/24] improve comment --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 9b79c025de5..da8d0703f65 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -723,7 +723,7 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); size_t num_outputs = results[0].size(); - // Next few calls are expected to output vectors of size [1 x outputs] + // Output dims and types are expected to have size [hlo_modules x outputs] std::vector> hlo_modules = pjrt_computation.executable->GetHloModules().value(); XLA_CHECK_EQ(hlo_modules.size(), 1) From 5d48684b92f81676865f9c7a379561090121a336 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 31 Oct 2023 18:14:33 +0000 Subject: [PATCH 08/24] Remove slow calls to get output shapes. --- .../csrc/runtime/pjrt_computation_client.cc | 83 +++++++++---------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index da8d0703f65..1d1f9a10603 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -1,6 +1,7 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include +#include #include #include @@ -652,6 +653,8 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); + + auto mwait = std::make_shared(arguments.size()); for (int32_t i = 0; i < arguments.size(); ++i) { auto buffer_converter = [&, i]() { auto pjrt_data = @@ -713,59 +716,51 @@ PjRtComputationClient::ExecuteReplicated( })); } - std::vector data_handles; - data_handles.reserve(results.size()); - std::vector dims(results.size()); + size_t num_outputs = results[0].size(); + std::vector data_handles(num_outputs); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_result_handle", tsl::profiler::TraceMeLevel::kInfo); - size_t num_outputs = results[0].size(); - - // Output dims and types are expected to have size [hlo_modules x outputs] - std::vector> hlo_modules = - pjrt_computation.executable->GetHloModules().value(); - XLA_CHECK_EQ(hlo_modules.size(), 1) - << "Expected one HLO module with multiple outputs."; - const xla::Shape& result_shape = hlo_modules[0]->result_shape(); - TF_VLOG(3) << "Processing output with shape " << result_shape.ToString(); - - std::vector output_dims = - pjrt_computation.executable->GetOutputDimensions().value()[0]; - XLA_CHECK_EQ(output_dims.size(), num_outputs); - std::vector output_types = - pjrt_computation.executable->GetOutputElementTypes().value()[0]; - XLA_CHECK_EQ(output_types.size(), num_outputs); - - std::vector output_shardings = - pjrt_computation.executable->GetOutputShardings().value(); - - XLA_CHECK_EQ(output_shardings.size(), num_outputs); - - for (int32_t i = 0; i < num_outputs; i++) { - std::vector> shards(devices.size()); - - for (int32_t d = 0; d < devices.size(); d++) { - XLA_CHECK_EQ(results[d].size(), num_outputs); - std::unique_ptr buffer = std::move(results[d][i]); - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); - XLA_CHECK(pjrt_device == buffer->device()) - << "Exepected device: " << pjrt_device->DebugString() - << " vs. actual device: " << buffer->device()->DebugString(); + xla::HloModuleConfig hlo_config(computation.program_shape()); + std::unique_ptr hlo_modules = ConsumeValue( + xla::HloModule::CreateFromProto(computation.computation().proto(), hlo_config)); + const xla::Shape& result_shape = hlo_modules->result_shape(); + TF_VLOG(3) << "Processing output with shape " << result_shape.ToString(); + const std::vector& output_shapes = + result_shape.IsTuple() ? hlo_modules->result_shape().tuple_shapes() + : std::vector({result_shape}); + xla::OpSharding output_sharding = hlo_modules->spmd_output_sharding().ToProto(); + std::vector output_shardings; + if (output_sharding.type() == xla::OpSharding::TUPLE) { + auto tuple_shardings = output_sharding.tuple_shardings(); + output_shardings = std::vector({tuple_shardings.begin(), tuple_shardings.end()}); + } else { + output_shardings = std::vector({output_sharding}); + } - shards[d] = std::make_shared(devices[d], std::move(buffer)); - } + auto mwait = std::make_shared(num_outputs); + for (int32_t i = 0; i < num_outputs; ++i) { + auto collect_shards = [&, i]() { + std::vector> shards(devices.size()); + for (int32_t d = 0; d < devices.size(); d++) { + std::unique_ptr buffer = std::move(results[d][i]); + shards[d] = std::make_shared(devices[d], std::move(buffer)); + } - auto data = std::make_shared( - spmd_device_str, - xla::ShapeUtil::MakeShape(output_types[i], output_dims[i]), - std::move(shards), output_shardings[i]); - TF_VLOG(5) << "Created sharded data with shape " - << data->shape().ToString(); - data_handles.push_back(data); + data_handles[i] = std::make_shared( + spmd_device_str, + output_shapes[i], + std::move(shards), output_shardings[i]); + TF_VLOG(5) << "Created sharded data with shape " + << data_handles[i]->shape().ToString(); + }; + env::ScheduleIoClosure(util::MultiWait::Completer( + mwait, std::move(collect_shards))); } + mwait->Wait(); } TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; From 6e1e0d3d28cfdab317bf1e99fc311791769eaf79 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 31 Oct 2023 18:48:11 +0000 Subject: [PATCH 09/24] fix implicit sharding --- .../csrc/runtime/pjrt_computation_client.cc | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1d1f9a10603..1c64b21869e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -732,14 +732,22 @@ PjRtComputationClient::ExecuteReplicated( const std::vector& output_shapes = result_shape.IsTuple() ? hlo_modules->result_shape().tuple_shapes() : std::vector({result_shape}); - xla::OpSharding output_sharding = hlo_modules->spmd_output_sharding().ToProto(); + std::vector output_shardings; - if (output_sharding.type() == xla::OpSharding::TUPLE) { - auto tuple_shardings = output_sharding.tuple_shardings(); - output_shardings = std::vector({tuple_shardings.begin(), tuple_shardings.end()}); + if (hlo_modules->has_spmd_output_sharding()) { + xla::OpSharding output_sharding = hlo_modules->spmd_output_sharding().ToProto(); + if (output_sharding.type() == xla::OpSharding::TUPLE) { + auto tuple_shardings = output_sharding.tuple_shardings(); + output_shardings = std::vector({tuple_shardings.begin(), tuple_shardings.end()}); + } else { + output_shardings = std::vector({output_sharding}); + } } else { - output_shardings = std::vector({output_sharding}); + // Without an explicit sharding annotation, the output is implicitly + // replicated + output_shardings = std::vector(output_shapes.size(), xla::HloSharding::Replicate().ToProto()); } + XLA_CHECK_EQ(output_shapes.size(), output_shardings.size()); auto mwait = std::make_shared(num_outputs); for (int32_t i = 0; i < num_outputs; ++i) { From 8b6fff3ddbecae48ac5c59c594368483ce37d391 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 31 Oct 2023 18:48:22 +0000 Subject: [PATCH 10/24] remove declarations of input/output handlers --- torch_xla/csrc/xla_sharding_util.h | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index f6846664790..848b1e62db1 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -58,32 +58,6 @@ class ShardingUtil { bool unroll_windowed_einsum = false, bool bidirectional_windowed_einsum = false); - // Reshuffles arguments (sharded or replicated) on the devices. The - // size of the arguments vector must match that of the sharding_specs. - // The the returned arguments will be in 1:1 correspondence with the `devices` - // vector, so the `i`th result will belong on the `i`th device. - // TODO(yeounoh) avoiding pre-loading of the unpartitioned input arguments - // might improve the performance and save the bandwidth. - static std::vector> - InputHandler(std::vector arguments, - std::vector devices); - - // Processes replicated execution results, where `sharded_results` contains - // `PjRtData` handles and spans the number of devices (outer) and the number - // of arguments (innner). This requires `sharding_specs` of the same size as - // the number of arguments. `sharding_specs` can contain `nullptr` if the - // corresponding result argument is not sharded. The replicated execution - // `replicated_output=true` leaves the results in replicated states, which is - // aligned with the default exepctation of XLA compiler. However, we override - // the compiler's default behavior and allow the execution to return sharded - // results and wrap sharded arguments into `PjRtShardedData`. This returns a - // vector of size that is equal to the number of arguments. - static std::vector OutputHandler( - std::vector> - sharded_results, - std::vector sharding_specs, - bool replicated_output = false); - // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all // have the same shape. From 64ef1257bfe88e847ce8bc879a44a0af0baac2e0 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 31 Oct 2023 18:48:57 +0000 Subject: [PATCH 11/24] formatting --- .../csrc/runtime/pjrt_computation_client.cc | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1c64b21869e..3c1399d1c05 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -725,8 +725,9 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); xla::HloModuleConfig hlo_config(computation.program_shape()); - std::unique_ptr hlo_modules = ConsumeValue( - xla::HloModule::CreateFromProto(computation.computation().proto(), hlo_config)); + std::unique_ptr hlo_modules = + ConsumeValue(xla::HloModule::CreateFromProto( + computation.computation().proto(), hlo_config)); const xla::Shape& result_shape = hlo_modules->result_shape(); TF_VLOG(3) << "Processing output with shape " << result_shape.ToString(); const std::vector& output_shapes = @@ -735,17 +736,20 @@ PjRtComputationClient::ExecuteReplicated( std::vector output_shardings; if (hlo_modules->has_spmd_output_sharding()) { - xla::OpSharding output_sharding = hlo_modules->spmd_output_sharding().ToProto(); + xla::OpSharding output_sharding = + hlo_modules->spmd_output_sharding().ToProto(); if (output_sharding.type() == xla::OpSharding::TUPLE) { auto tuple_shardings = output_sharding.tuple_shardings(); - output_shardings = std::vector({tuple_shardings.begin(), tuple_shardings.end()}); + output_shardings = std::vector( + {tuple_shardings.begin(), tuple_shardings.end()}); } else { output_shardings = std::vector({output_sharding}); } } else { // Without an explicit sharding annotation, the output is implicitly // replicated - output_shardings = std::vector(output_shapes.size(), xla::HloSharding::Replicate().ToProto()); + output_shardings = std::vector( + output_shapes.size(), xla::HloSharding::Replicate().ToProto()); } XLA_CHECK_EQ(output_shapes.size(), output_shardings.size()); @@ -759,14 +763,13 @@ PjRtComputationClient::ExecuteReplicated( } data_handles[i] = std::make_shared( - spmd_device_str, - output_shapes[i], - std::move(shards), output_shardings[i]); + spmd_device_str, output_shapes[i], std::move(shards), + output_shardings[i]); TF_VLOG(5) << "Created sharded data with shape " << data_handles[i]->shape().ToString(); }; - env::ScheduleIoClosure(util::MultiWait::Completer( - mwait, std::move(collect_shards))); + env::ScheduleIoClosure( + util::MultiWait::Completer(mwait, std::move(collect_shards))); } mwait->Wait(); } From 20be9cad72393571092e9e353bdf31d790b796d9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 1 Nov 2023 22:58:11 +0000 Subject: [PATCH 12/24] give everything a manual placeholder sharding --- .../csrc/runtime/pjrt_computation_client.cc | 29 ++++--------------- torch_xla/csrc/xla_graph_executor.cpp | 1 + 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 3c1399d1c05..2869b0e8720 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -724,33 +724,16 @@ PjRtComputationClient::ExecuteReplicated( "PjRtComputationClient::ExecuteReplicated_result_handle", tsl::profiler::TraceMeLevel::kInfo); - xla::HloModuleConfig hlo_config(computation.program_shape()); - std::unique_ptr hlo_modules = - ConsumeValue(xla::HloModule::CreateFromProto( - computation.computation().proto(), hlo_config)); - const xla::Shape& result_shape = hlo_modules->result_shape(); + const xla::Shape& result_shape = computation.program_shape().result(); TF_VLOG(3) << "Processing output with shape " << result_shape.ToString(); const std::vector& output_shapes = - result_shape.IsTuple() ? hlo_modules->result_shape().tuple_shapes() + result_shape.IsTuple() ? result_shape.tuple_shapes() : std::vector({result_shape}); - std::vector output_shardings; - if (hlo_modules->has_spmd_output_sharding()) { - xla::OpSharding output_sharding = - hlo_modules->spmd_output_sharding().ToProto(); - if (output_sharding.type() == xla::OpSharding::TUPLE) { - auto tuple_shardings = output_sharding.tuple_shardings(); - output_shardings = std::vector( - {tuple_shardings.begin(), tuple_shardings.end()}); - } else { - output_shardings = std::vector({output_sharding}); - } - } else { - // Without an explicit sharding annotation, the output is implicitly - // replicated - output_shardings = std::vector( - output_shapes.size(), xla::HloSharding::Replicate().ToProto()); - } + // Without an explicit sharding annotation, the output is implicitly + // replicated + std::vector output_shardings = std::vector( + output_shapes.size(), xla::HloSharding::Manual().ToProto()); XLA_CHECK_EQ(output_shapes.size(), output_shardings.size()); auto mwait = std::make_shared(num_outputs); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 2b193bdbb96..83717b299f8 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1001,6 +1001,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( if (async->tensors_data[i] != nullptr) { async->tensors_data[i]->Assign(*results[i]); } else { + XLA_ERROR() << "TODO: remove this path if it is not being used."; async->tensors_data[i] = std::move(results[i]); } } From ac6141377f6bad9d847054a94fbbeba4a9f7ba22 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 2 Nov 2023 00:07:18 +0000 Subject: [PATCH 13/24] see if CI passes --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 6 +++--- torch_xla/csrc/xla_graph_executor.cpp | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 2869b0e8720..ff92dc368e9 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -729,11 +729,11 @@ PjRtComputationClient::ExecuteReplicated( const std::vector& output_shapes = result_shape.IsTuple() ? result_shape.tuple_shapes() : std::vector({result_shape}); + XLA_CHECK_EQ(output_shapes.size(), num_outputs); - // Without an explicit sharding annotation, the output is implicitly - // replicated + // HACK: we don't use the sharding on this DataPtr anyway std::vector output_shardings = std::vector( - output_shapes.size(), xla::HloSharding::Manual().ToProto()); + output_shapes.size(), xla::HloSharding::Unknown().ToProto()); XLA_CHECK_EQ(output_shapes.size(), output_shardings.size()); auto mwait = std::make_shared(num_outputs); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 83717b299f8..40147263d64 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1001,7 +1001,8 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( if (async->tensors_data[i] != nullptr) { async->tensors_data[i]->Assign(*results[i]); } else { - XLA_ERROR() << "TODO: remove this path if it is not being used."; + // TODO see if this passes CI + XLA_CHECK(!std::dynamic_pointer_cast(results[i])->HasSharding() || std::dynamic_pointer_cast(results[i])->GetSharding().type() != xla::OpSharding::UNKNOWN) << "TODO: remove this path if it is not being used."; async->tensors_data[i] = std::move(results[i]); } } From 806a7bd1ec1fdd0d2d2421a53c34b46726cd281e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 2 Nov 2023 00:08:10 +0000 Subject: [PATCH 14/24] formatting --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 5 +++-- torch_xla/csrc/xla_graph_executor.cpp | 10 +++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index ff92dc368e9..b01a3953b04 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -732,8 +732,9 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK_EQ(output_shapes.size(), num_outputs); // HACK: we don't use the sharding on this DataPtr anyway - std::vector output_shardings = std::vector( - output_shapes.size(), xla::HloSharding::Unknown().ToProto()); + std::vector output_shardings = + std::vector(output_shapes.size(), + xla::HloSharding::Unknown().ToProto()); XLA_CHECK_EQ(output_shapes.size(), output_shardings.size()); auto mwait = std::make_shared(num_outputs); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 40147263d64..62eff2d1a27 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1002,7 +1002,15 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( async->tensors_data[i]->Assign(*results[i]); } else { // TODO see if this passes CI - XLA_CHECK(!std::dynamic_pointer_cast(results[i])->HasSharding() || std::dynamic_pointer_cast(results[i])->GetSharding().type() != xla::OpSharding::UNKNOWN) << "TODO: remove this path if it is not being used."; + XLA_CHECK( + !std::dynamic_pointer_cast( + results[i]) + ->HasSharding() || + std::dynamic_pointer_cast( + results[i]) + ->GetSharding() + .type() != xla::OpSharding::UNKNOWN) + << "TODO: remove this path if it is not being used."; async->tensors_data[i] = std::move(results[i]); } } From 787b568968b6f3cec3f1b12dc1f4f2a20343c90a Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 3 Nov 2023 17:05:48 +0000 Subject: [PATCH 15/24] Shard parameter and output handling --- torch_xla/csrc/runtime/BUILD | 1 + .../csrc/runtime/pjrt_computation_client.cc | 41 ++++++++++--------- .../csrc/runtime/pjrt_computation_client.h | 4 ++ 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index fa7e3578729..df2b104876f 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -101,6 +101,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", + "@tsl//tsl/platform:env", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index b01a3953b04..eb608a838b2 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -654,9 +654,13 @@ PjRtComputationClient::ExecuteReplicated( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); - auto mwait = std::make_shared(arguments.size()); - for (int32_t i = 0; i < arguments.size(); ++i) { - auto buffer_converter = [&, i]() { + util::MultiWait mwait(arguments.size()); + // TODO: tune and document cost estimate + pool_.ParallelFor(arguments.size(), 100000, [&](int64_t start, int64_t end) { + tsl::profiler::TraceMe activity( + "PjRtComputationClient::ExecuteReplicated_argument_handle_shard", + tsl::profiler::TraceMeLevel::kInfo); + for (int32_t i = start; i < end; ++i) { auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) @@ -670,11 +674,12 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); argument_handles[d][i] = shard->buffer.get(); + mwait.Done(); } counter.DecrementCount(); }; thread::Schedule(std::move(buffer_converter)); - } + }); counter.Wait(); } @@ -731,15 +736,13 @@ PjRtComputationClient::ExecuteReplicated( : std::vector({result_shape}); XLA_CHECK_EQ(output_shapes.size(), num_outputs); - // HACK: we don't use the sharding on this DataPtr anyway - std::vector output_shardings = - std::vector(output_shapes.size(), - xla::HloSharding::Unknown().ToProto()); - XLA_CHECK_EQ(output_shapes.size(), output_shardings.size()); - - auto mwait = std::make_shared(num_outputs); - for (int32_t i = 0; i < num_outputs; ++i) { - auto collect_shards = [&, i]() { + util::MultiWait mwait(num_outputs); + // TODO: tune and document cost estimate + pool_.ParallelFor(num_outputs, 100000, [&](int64_t start, int64_t end) { + tsl::profiler::TraceMe activity( + "PjRtComputationClient::ExecuteReplicated_result_handle_shard", + tsl::profiler::TraceMeLevel::kInfo); + for (int32_t i = start; i < end; ++i) { std::vector> shards(devices.size()); for (int32_t d = 0; d < devices.size(); d++) { std::unique_ptr buffer = std::move(results[d][i]); @@ -748,14 +751,14 @@ PjRtComputationClient::ExecuteReplicated( data_handles[i] = std::make_shared( spmd_device_str, output_shapes[i], std::move(shards), - output_shardings[i]); + // HACK: we don't use the sharding on this DataPtr anyway + xla::HloSharding::Unknown().ToProto()); TF_VLOG(5) << "Created sharded data with shape " << data_handles[i]->shape().ToString(); - }; - env::ScheduleIoClosure( - util::MultiWait::Completer(mwait, std::move(collect_shards))); - } - mwait->Wait(); + mwait.Done(); + } + }); + mwait.Wait(); } TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 77866953ed0..ce9a744f4d2 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -12,6 +12,8 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" @@ -110,6 +112,7 @@ class PjRtComputationClient : public ComputationClient { std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; OperationManager operation_manager_; + tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency()); xla::PjRtDevice* StringToPjRtDevice(const std::string& device); @@ -234,6 +237,7 @@ class PjRtComputationClient : public ComputationClient { // Use XLA replication to re-assemble the sharded data. std::shared_ptr ReplicateShardedData(const DataPtr& handle); + }; } // namespace runtime From 109b5de5e96727d3c1cb0012f3b2b4caf794bd39 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 3 Nov 2023 20:02:53 +0000 Subject: [PATCH 16/24] Use absl::BlockingCounter --- .../csrc/runtime/pjrt_computation_client.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index eb608a838b2..5225e66bc09 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -646,17 +646,16 @@ PjRtComputationClient::ExecuteReplicated( const PjRtComputation& pjrt_computation = dynamic_cast(computation); - absl::BlockingCounter counter(arguments.size()); - std::vector> argument_handles( - devices.size(), std::vector(arguments.size())); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); - util::MultiWait mwait(arguments.size()); + absl::BlockingCounter counter(arguments.size()); + std::vector> argument_handles( + devices.size(), std::vector(arguments.size())); // TODO: tune and document cost estimate - pool_.ParallelFor(arguments.size(), 100000, [&](int64_t start, int64_t end) { + pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle_shard", tsl::profiler::TraceMeLevel::kInfo); @@ -674,7 +673,6 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); argument_handles[d][i] = shard->buffer.get(); - mwait.Done(); } counter.DecrementCount(); }; @@ -736,9 +734,9 @@ PjRtComputationClient::ExecuteReplicated( : std::vector({result_shape}); XLA_CHECK_EQ(output_shapes.size(), num_outputs); - util::MultiWait mwait(num_outputs); + absl::BlockingCounter counter(num_outputs); // TODO: tune and document cost estimate - pool_.ParallelFor(num_outputs, 100000, [&](int64_t start, int64_t end) { + pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_result_handle_shard", tsl::profiler::TraceMeLevel::kInfo); @@ -755,10 +753,10 @@ PjRtComputationClient::ExecuteReplicated( xla::HloSharding::Unknown().ToProto()); TF_VLOG(5) << "Created sharded data with shape " << data_handles[i]->shape().ToString(); - mwait.Done(); + counter.DecrementCount(); } }); - mwait.Wait(); + counter.Wait(); } TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; From 6e1c28ae240d4e0617a043092b629815860b2730 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 3 Nov 2023 20:03:46 +0000 Subject: [PATCH 17/24] formatting --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 8 ++++---- torch_xla/csrc/runtime/pjrt_computation_client.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 5225e66bc09..cfd6b76b659 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -657,8 +657,8 @@ PjRtComputationClient::ExecuteReplicated( // TODO: tune and document cost estimate pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) { tsl::profiler::TraceMe activity( - "PjRtComputationClient::ExecuteReplicated_argument_handle_shard", - tsl::profiler::TraceMeLevel::kInfo); + "PjRtComputationClient::ExecuteReplicated_argument_handle_shard", + tsl::profiler::TraceMeLevel::kInfo); for (int32_t i = start; i < end; ++i) { auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); @@ -738,8 +738,8 @@ PjRtComputationClient::ExecuteReplicated( // TODO: tune and document cost estimate pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) { tsl::profiler::TraceMe activity( - "PjRtComputationClient::ExecuteReplicated_result_handle_shard", - tsl::profiler::TraceMeLevel::kInfo); + "PjRtComputationClient::ExecuteReplicated_result_handle_shard", + tsl::profiler::TraceMeLevel::kInfo); for (int32_t i = start; i < end; ++i) { std::vector> shards(devices.size()); for (int32_t d = 0; d < devices.size(); d++) { diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index ce9a744f4d2..0d8c7629d76 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -112,7 +112,8 @@ class PjRtComputationClient : public ComputationClient { std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; OperationManager operation_manager_; - tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency()); + tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( + tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency()); xla::PjRtDevice* StringToPjRtDevice(const std::string& device); @@ -237,7 +238,6 @@ class PjRtComputationClient : public ComputationClient { // Use XLA replication to re-assemble the sharded data. std::shared_ptr ReplicateShardedData(const DataPtr& handle); - }; } // namespace runtime From 464c6c3bbeebd51e740d9762f8cf145845d50646 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 27 Nov 2023 18:16:29 +0000 Subject: [PATCH 18/24] fix merge --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index cfd6b76b659..c8e5aa5884c 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -646,14 +646,14 @@ PjRtComputationClient::ExecuteReplicated( const PjRtComputation& pjrt_computation = dynamic_cast(computation); + std::vector> argument_handles( + devices.size(), std::vector(arguments.size())); { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); absl::BlockingCounter counter(arguments.size()); - std::vector> argument_handles( - devices.size(), std::vector(arguments.size())); // TODO: tune and document cost estimate pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) { tsl::profiler::TraceMe activity( @@ -676,7 +676,6 @@ PjRtComputationClient::ExecuteReplicated( } counter.DecrementCount(); }; - thread::Schedule(std::move(buffer_converter)); }); counter.Wait(); } From 11bf3ff77140765a53acb89667740c2af5872848 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 28 Nov 2023 22:11:49 +0000 Subject: [PATCH 19/24] Assign valid output shardings --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 7 +++++-- torch_xla/csrc/runtime/pjrt_computation_client.h | 5 ++++- torch_xla/csrc/xla_graph_executor.cpp | 10 ---------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index c8e5aa5884c..08295c1ed8a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -733,6 +733,10 @@ PjRtComputationClient::ExecuteReplicated( : std::vector({result_shape}); XLA_CHECK_EQ(output_shapes.size(), num_outputs); + XLA_CHECK(pjrt_computation.output_shardings_.has_value()); + const std::vector& output_shardings = *pjrt_computation.output_shardings_; + XLA_CHECK_EQ(output_shardings.size(), num_outputs); + absl::BlockingCounter counter(num_outputs); // TODO: tune and document cost estimate pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) { @@ -748,8 +752,7 @@ PjRtComputationClient::ExecuteReplicated( data_handles[i] = std::make_shared( spmd_device_str, output_shapes[i], std::move(shards), - // HACK: we don't use the sharding on this DataPtr anyway - xla::HloSharding::Unknown().ToProto()); + output_shardings[i]); TF_VLOG(5) << "Created sharded data with shape " << data_handles[i]->shape().ToString(); counter.DecrementCount(); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 0d8c7629d76..93760ee6b05 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -231,9 +231,12 @@ class PjRtComputationClient : public ComputationClient { std::vector devices, std::unique_ptr executable) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) {} + executable(std::move(executable)) { + output_shardings_ = this->executable->GetOutputShardings(); + } std::unique_ptr executable; + std::optional> output_shardings_; }; // Use XLA replication to re-assemble the sharded data. diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 62eff2d1a27..2b193bdbb96 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1001,16 +1001,6 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( if (async->tensors_data[i] != nullptr) { async->tensors_data[i]->Assign(*results[i]); } else { - // TODO see if this passes CI - XLA_CHECK( - !std::dynamic_pointer_cast( - results[i]) - ->HasSharding() || - std::dynamic_pointer_cast( - results[i]) - ->GetSharding() - .type() != xla::OpSharding::UNKNOWN) - << "TODO: remove this path if it is not being used."; async->tensors_data[i] = std::move(results[i]); } } From dba6d1caa44e9774c5cbf804bb50e7e23dddb1da Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 00:05:52 +0000 Subject: [PATCH 20/24] tune and document costs --- .../csrc/runtime/pjrt_computation_client.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 08295c1ed8a..4ee45a68d24 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -654,11 +654,11 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMeLevel::kInfo); absl::BlockingCounter counter(arguments.size()); - // TODO: tune and document cost estimate - pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) { - tsl::profiler::TraceMe activity( - "PjRtComputationClient::ExecuteReplicated_argument_handle_shard", - tsl::profiler::TraceMeLevel::kInfo); + + // Time in nanoseconds that it takes to prepare an argument. Used to tune + // number of threads spawned by ParallelFor. Measured on 2023/11/28. + static constexpr int64_t argument_handle_cost_ns = 10000; + pool_.ParallelFor(arguments.size(), argument_handle_cost_ns, [&](int64_t start, int64_t end) { for (int32_t i = start; i < end; ++i) { auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); @@ -738,11 +738,11 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); - // TODO: tune and document cost estimate - pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) { - tsl::profiler::TraceMe activity( - "PjRtComputationClient::ExecuteReplicated_result_handle_shard", - tsl::profiler::TraceMeLevel::kInfo); + + // Time in nanoseconds that it takes to process a result buffer. + // Measured on 2023/11/28. + static constexpr int64_t result_handle_cost_ns = 10000; + pool_.ParallelFor(num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) { for (int32_t i = start; i < end; ++i) { std::vector> shards(devices.size()); for (int32_t d = 0; d < devices.size(); d++) { From 9dcd75e7e6e91e9243899d4b9988d853bb504020 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 00:11:00 +0000 Subject: [PATCH 21/24] formatting --- .../csrc/runtime/pjrt_computation_client.cc | 79 ++++++++++--------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 4ee45a68d24..58d7ca1f3b2 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -658,25 +658,28 @@ PjRtComputationClient::ExecuteReplicated( // Time in nanoseconds that it takes to prepare an argument. Used to tune // number of threads spawned by ParallelFor. Measured on 2023/11/28. static constexpr int64_t argument_handle_cost_ns = 10000; - pool_.ParallelFor(arguments.size(), argument_handle_cost_ns, [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - auto pjrt_data = - std::dynamic_pointer_cast(arguments[i]); - XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) - << "Expected one shard per device"; - - for (int32_t d = 0; d < devices.size(); d++) { - std::shared_ptr shard = pjrt_data->shards[d]; - - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); - XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); - XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); - - argument_handles[d][i] = shard->buffer.get(); - } - counter.DecrementCount(); - }; - }); + pool_.ParallelFor( + arguments.size(), argument_handle_cost_ns, + [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + auto pjrt_data = + std::dynamic_pointer_cast(arguments[i]); + XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) + << "Expected one shard per device"; + + for (int32_t d = 0; d < devices.size(); d++) { + std::shared_ptr shard = pjrt_data->shards[d]; + + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); + XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); + XLA_CHECK(pjrt_device->IsAddressable()) + << pjrt_device->DebugString(); + + argument_handles[d][i] = shard->buffer.get(); + } + counter.DecrementCount(); + }; + }); counter.Wait(); } @@ -734,7 +737,8 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK_EQ(output_shapes.size(), num_outputs); XLA_CHECK(pjrt_computation.output_shardings_.has_value()); - const std::vector& output_shardings = *pjrt_computation.output_shardings_; + const std::vector& output_shardings = + *pjrt_computation.output_shardings_; XLA_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); @@ -742,22 +746,25 @@ PjRtComputationClient::ExecuteReplicated( // Time in nanoseconds that it takes to process a result buffer. // Measured on 2023/11/28. static constexpr int64_t result_handle_cost_ns = 10000; - pool_.ParallelFor(num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) { - for (int32_t i = start; i < end; ++i) { - std::vector> shards(devices.size()); - for (int32_t d = 0; d < devices.size(); d++) { - std::unique_ptr buffer = std::move(results[d][i]); - shards[d] = std::make_shared(devices[d], std::move(buffer)); - } - - data_handles[i] = std::make_shared( - spmd_device_str, output_shapes[i], std::move(shards), - output_shardings[i]); - TF_VLOG(5) << "Created sharded data with shape " - << data_handles[i]->shape().ToString(); - counter.DecrementCount(); - } - }); + pool_.ParallelFor( + num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) { + for (int32_t i = start; i < end; ++i) { + std::vector> shards(devices.size()); + for (int32_t d = 0; d < devices.size(); d++) { + std::unique_ptr buffer = + std::move(results[d][i]); + shards[d] = + std::make_shared(devices[d], std::move(buffer)); + } + + data_handles[i] = std::make_shared( + spmd_device_str, output_shapes[i], std::move(shards), + output_shardings[i]); + TF_VLOG(5) << "Created sharded data with shape " + << data_handles[i]->shape().ToString(); + counter.DecrementCount(); + } + }); counter.Wait(); } From c2b88120ec6d4473c8c924c91d3afbca2789d1b3 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 00:20:29 +0000 Subject: [PATCH 22/24] implicitly replicate output to match outputhandler --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 58d7ca1f3b2..54f32a86932 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -736,9 +736,13 @@ PjRtComputationClient::ExecuteReplicated( : std::vector({result_shape}); XLA_CHECK_EQ(output_shapes.size(), num_outputs); - XLA_CHECK(pjrt_computation.output_shardings_.has_value()); const std::vector& output_shardings = - *pjrt_computation.output_shardings_; + pjrt_computation.output_shardings_ + ? *pjrt_computation.output_shardings_ + : + // Without an explicit sharding annotation, the output is implicitly + // replicated + std::vector(num_outputs, xla::HloSharding::Replicate().ToProto()); XLA_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); From 9bc594e00a915b2057823339f58d734378f0f013 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 29 Nov 2023 21:24:41 +0000 Subject: [PATCH 23/24] clarify ReplicateShardedData --- .../csrc/runtime/pjrt_computation_client.cc | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 54f32a86932..76a9d7c19b6 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -390,11 +390,13 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::shared_ptr PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { - if (PjRtShardedData* sharded_data = - dynamic_cast(handle.get())) { + if (auto unsharded_data = std::dynamic_pointer_cast(handle)) { + return unsharded_data; + } else if (auto sharded_data = + std::dynamic_pointer_cast(handle)) { XLA_COUNTER("ReplicateShardedData", 1); - TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() - << ", shape=" << handle->shape() << ")"; + TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle() + << ", shape=" << sharded_data->shape() << ")"; if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { // Data is replicated, return the first shard return sharded_data->shards[0]; @@ -432,8 +434,9 @@ PjRtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = ExecuteReplicated( - *computations.front(), {handle}, GetLocalDevices(), execute_options); + auto sharded_results = + ExecuteReplicated(*computations.front(), {sharded_data}, + GetLocalDevices(), execute_options); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; XLA_CHECK(sharded_results.size() == 1) @@ -442,9 +445,9 @@ PjRtComputationClient::ReplicateShardedData( return std::dynamic_pointer_cast(sharded_results[0]) ->shards[0]; } - auto pjrt_data = std::dynamic_pointer_cast(handle); - XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData."; - return pjrt_data; + + XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got " + << handle->ToString(); } std::vector PjRtComputationClient::TransferFromServer( From 674f2c3839b2e1280b0910cc87ab925a45aeb343 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 30 Nov 2023 18:22:02 +0000 Subject: [PATCH 24/24] fix merge --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 76a9d7c19b6..0bd42b3ad6a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -467,8 +467,8 @@ std::vector PjRtComputationClient::TransferFromServer( XLA_CHECK(pjrt_data); xla::Literal& literal = - literals.emplace_back(host_output_shape(pjrt_data.buffer.get())); - futures.push_back(pjrt_data.buffer->ToLiteral(&literal)); + literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); + futures.push_back(pjrt_data->buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); }