diff --git a/WORKSPACE b/WORKSPACE index bf7d8c0c137..600d280519e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -42,6 +42,7 @@ http_archive( "//openxla_patches:f16_abi_clang.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:constexpr_return.diff", + "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff", ], strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18", urls = [ diff --git a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff new file mode 100644 index 00000000000..ee1ec00eced --- /dev/null +++ b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff @@ -0,0 +1,76 @@ +# Partial backport of 6308dba2903e78961ac4122f361bc91b09f36891. Remove in next +# pin update. +diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc +index ef0b6686c..c0341e81e 100644 +--- a/xla/pjrt/pjrt_c_api_client.cc ++++ b/xla/pjrt/pjrt_c_api_client.cc +@@ -1584,6 +1584,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const { + return args.num_dynamic_dims > 0; + } + ++absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { ++ { ++ absl::MutexLock lock(&mu_); ++ if (!is_dynamic_dimension_.has_value()) { ++ absl::InlinedVector& is_dynamic_dimension_value = ++ is_dynamic_dimension_.emplace(); ++ is_dynamic_dimension_value.assign(dimensions().size(), false); ++ ++ PJRT_Buffer_DynamicDimensionIndices_Args args; ++ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; ++ args.priv = nullptr; ++ args.buffer = buffer_.get(); ++ const PJRT_Api* api = pjrt_c_api(); ++ std::unique_ptr error( ++ api->PJRT_Buffer_DynamicDimensionIndices(&args), ++ pjrt::MakeErrorDeleter(api)); ++ if (error && pjrt::GetErrorCode(error.get(), api) == ++ PJRT_Error_Code_UNIMPLEMENTED) { ++ return *is_dynamic_dimension_; ++ } ++ for (int i = 0; i < args.num_dynamic_dims; ++i) { ++ is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true; ++ } ++ } ++ } ++ return *is_dynamic_dimension_; ++} ++ + StatusOr> PjRtCApiBuffer::logical_dimensions() { + PJRT_Buffer_UnpaddedDimensions_Args args; + args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; +diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h +index 9c460f246..279608e60 100644 +--- a/xla/pjrt/pjrt_c_api_client.h ++++ b/xla/pjrt/pjrt_c_api_client.h +@@ -27,6 +27,7 @@ limitations under the License. + #include + + #include "absl/container/flat_hash_map.h" ++#include "absl/container/inlined_vector.h" + #include "absl/log/check.h" + #include "absl/log/log.h" + #include "absl/strings/string_view.h" +@@ -369,11 +370,7 @@ class PjRtCApiBuffer : public PjRtBuffer { + + bool has_dynamic_dimensions() const override; + +- absl::Span is_dynamic_dimension() const override { +- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. " +- << "Considering using has_dynamic_dimensions() or " +- "logical_dimensions() if applicable."; +- } ++ absl::Span is_dynamic_dimension() const override; + + StatusOr> logical_dimensions() override; + +@@ -455,6 +452,9 @@ class PjRtCApiBuffer : public PjRtBuffer { + std::shared_ptr::Promise> readiness_promise_; + // Set and cached the first time layout() is called. + mutable std::optional layout_; ++ // Set and cached the first time is_dynamic_dimension() is called. ++ mutable std::optional> ++ is_dynamic_dimension_; + // Used to synchronize concurrent setting of cached values. + mutable absl::Mutex mu_; + }; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index eea50940b78..5a965d318a6 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -70,6 +70,15 @@ std::unordered_map build_index_map( return device_index; } +// Builds the xla::Shape of the output xla::Literal on the host. +xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { + xla::Shape shape = xla::ShapeUtil::MakeShape( + buffer->element_type(), buffer->logical_dimensions().value()); + *shape.mutable_layout() = buffer->layout(); + + return xla::ShapeUtil::DeviceShapeToHostShape(shape); +} + } // namespace std::string PjRtComputationClient::PjRtDeviceToString( @@ -424,9 +433,8 @@ std::vector PjRtComputationClient::TransferFromServer( auto new_handle = ReplicateShardedData(handle); const PjRtData& pjrt_data = dynamic_cast(*new_handle); - xla::Shape target_shape = xla::ShapeUtil::DeviceShapeToHostShape( - pjrt_data.buffer->logical_on_device_shape().value()); - auto& literal = literals.emplace_back(target_shape); + auto& literal = + literals.emplace_back(host_output_shape(pjrt_data.buffer.get())); XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal)); total_size += literal.size_bytes(); @@ -569,8 +577,8 @@ PjRtComputationClient::ExecuteComputation( for (auto& result : results) { std::unique_ptr buffer = std::move(result); - std::shared_ptr data = std::make_shared( - device, buffer->on_device_shape(), std::move(buffer)); + std::shared_ptr data = + std::make_shared(device, std::move(buffer)); datas.push_back(data); } @@ -697,8 +705,8 @@ PjRtComputationClient::ExecuteReplicated( << "Exepcted device: " << pjrt_device->DebugString() << " vs. actual device: " << buffer->device()->DebugString(); - std::shared_ptr data = std::make_shared( - devices[i], buffer->on_device_shape(), std::move(buffer)); + std::shared_ptr data = + std::make_shared(devices[i], std::move(buffer)); datas.push_back(data); } data_handles.push_back(datas); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index a8e7b3e6985..a4c4f58aab0 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -164,6 +164,12 @@ class PjRtComputationClient : public ComputationClient { std::shared_ptr buffer) : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} + PjRtData(std::string device, std::shared_ptr buffer) + : Data(std::move(device), + xla::Shape(buffer->element_type(), buffer->dimensions(), + buffer->is_dynamic_dimension(), {})), + buffer(buffer) {} + OpaqueHandle GetOpaqueHandle() override { XLA_CHECK(HasValue()) << "buffer with shape " << shape().ToString() << " on device "