From 984b3036ea873dd6a51b7a7b08302192f98dbbd0 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 7 Sep 2023 22:00:27 +0000 Subject: [PATCH] Patch in `is_dynamic_dimension` --- WORKSPACE | 1 + .../pjrt_c_api_dynamic_dimensions.diff | 46 +++++++++++++++++++ .../csrc/runtime/pjrt_computation_client.cc | 25 ++-------- .../csrc/runtime/pjrt_computation_client.h | 6 +++ 4 files changed, 57 insertions(+), 21 deletions(-) create mode 100644 openxla_patches/pjrt_c_api_dynamic_dimensions.diff diff --git a/WORKSPACE b/WORKSPACE index 790f6d8f31cb..57d34a791f0a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -42,6 +42,7 @@ http_archive( "//openxla_patches:f16_abi_clang.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:constexpr_return.diff", + "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff", ], strip_prefix = "xla-cd2cf5c34931e4fc1cacf83bfc480a5b93f05f6d", urls = [ diff --git a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff new file mode 100644 index 000000000000..7b2ace1602d0 --- /dev/null +++ b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff @@ -0,0 +1,46 @@ +diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc +index 565aa2208..2278ab6c4 100644 +--- a/xla/pjrt/pjrt_c_api_client.cc ++++ b/xla/pjrt/pjrt_c_api_client.cc +@@ -1658,6 +1658,24 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const { + return args.num_dynamic_dims > 0; + } + ++ ++absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { ++ PJRT_Buffer_DynamicDimensionIndices_Args args; ++ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; ++ args.priv = nullptr; ++ args.buffer = buffer_.get(); ++ ++ pjrt::LogFatalIfPjrtError( ++ pjrt_c_api()->PJRT_Buffer_DynamicDimensionIndices(&args), pjrt_c_api()); ++ ++ absl::InlinedVector dynamic_dimensions(dimensions().size()); ++ for (int i = 0; i < args.num_dynamic_dims; ++i) { ++ dynamic_dimensions[args.dynamic_dim_indices[i]] = true; ++ } ++ ++ return dynamic_dimensions; ++} ++ + StatusOr> PjRtCApiBuffer::logical_dimensions() { + PJRT_Buffer_UnpaddedDimensions_Args args; + args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; +diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h +index b2e2de349..2687b5371 100644 +--- a/xla/pjrt/pjrt_c_api_client.h ++++ b/xla/pjrt/pjrt_c_api_client.h +@@ -379,11 +379,7 @@ class PjRtCApiBuffer : public PjRtBuffer { + + bool has_dynamic_dimensions() const override; + +- absl::Span is_dynamic_dimension() const override { +- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. " +- << "Considering using has_dynamic_dimensions() or " +- "logical_dimensions() if applicable."; +- } ++ absl::Span is_dynamic_dimension() const override; + + StatusOr> logical_dimensions() override; + diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 4bc35eb58415..8ad613520865 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -71,23 +71,6 @@ std::unordered_map build_index_map( return device_index; } -// TODO: Do we care about layout here? -xla::Shape on_device_shape(xla::PjRtBuffer* buffer) { - auto dimensions = buffer->dimensions(); - auto size = dimensions.size(); - // TODO: use method when implemented - std::vector dynamic_dimensions(size); - if (buffer->has_dynamic_dimensions()) { - auto logical_dimensions = buffer->logical_dimensions().value(); - for (int i = 0; i < size; ++i) { - dynamic_dimensions[i] = dimensions[i] != logical_dimensions[i]; - } - } - - return xla::ShapeUtil::MakeShape(buffer->element_type(), dimensions, - dynamic_dimensions); -} - // Builds the xla::Shape of the output xla::Literal on the host. xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { xla::Shape shape = xla::ShapeUtil::MakeShape( @@ -595,8 +578,8 @@ PjRtComputationClient::ExecuteComputation( for (auto& result : results) { std::unique_ptr buffer = std::move(result); - std::shared_ptr data = std::make_shared( - device, on_device_shape(buffer.get()), std::move(buffer)); + std::shared_ptr data = + std::make_shared(device, std::move(buffer)); datas.push_back(data); } @@ -723,8 +706,8 @@ PjRtComputationClient::ExecuteReplicated( << "Exepcted device: " << pjrt_device->DebugString() << " vs. actual device: " << buffer->device()->DebugString(); - std::shared_ptr data = std::make_shared( - devices[i], on_device_shape(buffer.get()), std::move(buffer)); + std::shared_ptr data = + std::make_shared(devices[i], std::move(buffer)); datas.push_back(data); } data_handles.push_back(datas); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index c78e03329522..8384c554f8aa 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -164,6 +164,12 @@ class PjRtComputationClient : public ComputationClient { std::shared_ptr buffer) : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} + PjRtData(std::string device, std::shared_ptr buffer) + : Data(std::move(device), + xla::Shape(buffer->element_type(), buffer->dimensions(), + buffer->is_dynamic_dimension(), {})), + buffer(buffer) {} + OpaqueHandle GetOpaqueHandle() override { XLA_CHECK(HasValue()) << "buffer with shape " << shape().ToString() << " on device "