Skip to content

Commit

Permalink
Update patch with real pending change
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Sep 11, 2023
1 parent 3174b9d commit e180883
Showing 1 changed file with 45 additions and 17 deletions.
62 changes: 45 additions & 17 deletions openxla_patches/pjrt_c_api_dynamic_dimensions.diff
Original file line number Diff line number Diff line change
@@ -1,37 +1,55 @@
diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
index 565aa2208..2278ab6c4 100644
index d45b9377c..92d244750 100644
--- a/xla/pjrt/pjrt_c_api_client.cc
+++ b/xla/pjrt/pjrt_c_api_client.cc
@@ -1658,6 +1658,24 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const {
@@ -1489,6 +1489,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const {
return args.num_dynamic_dims > 0;
}

+
+absl::Span<const bool> PjRtCApiBuffer::is_dynamic_dimension() const {
+ PJRT_Buffer_DynamicDimensionIndices_Args args;
+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
+ args.priv = nullptr;
+ args.buffer = buffer_.get();
+
+ pjrt::LogFatalIfPjrtError(
+ pjrt_c_api()->PJRT_Buffer_DynamicDimensionIndices(&args), pjrt_c_api());
+ {
+ absl::MutexLock lock(&mu_);
+ if (!is_dynamic_dimension_.has_value()) {
+ absl::InlinedVector<bool, InlineRank()>& is_dynamic_dimension_value =
+ is_dynamic_dimension_.emplace();
+ is_dynamic_dimension_value.assign(dimensions().size(), false);
+
+ absl::InlinedVector<bool, 4> dynamic_dimensions(dimensions().size());
+ for (int i = 0; i < args.num_dynamic_dims; ++i) {
+ dynamic_dimensions[args.dynamic_dim_indices[i]] = true;
+ PJRT_Buffer_DynamicDimensionIndices_Args args;
+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
+ args.priv = nullptr;
+ args.buffer = buffer_.get();
+ const PJRT_Api* api = pjrt_c_api();
+ std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> error(
+ api->PJRT_Buffer_DynamicDimensionIndices(&args),
+ pjrt::MakeErrorDeleter(api));
+ if (error && pjrt::GetErrorCode(error.get(), api) ==
+ PJRT_Error_Code_UNIMPLEMENTED) {
+ return *is_dynamic_dimension_;
+ }
+ for (int i = 0; i < args.num_dynamic_dims; ++i) {
+ is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true;
+ }
+ }
+ }
+
+ return dynamic_dimensions;
+ return *is_dynamic_dimension_;
+}
+
StatusOr<std::vector<int64_t>> PjRtCApiBuffer::logical_dimensions() {
PJRT_Buffer_UnpaddedDimensions_Args args;
args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE;
diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h
index b2e2de349..2687b5371 100644
index b9597f923..5101b09a9 100644
--- a/xla/pjrt/pjrt_c_api_client.h
+++ b/xla/pjrt/pjrt_c_api_client.h
@@ -379,11 +379,7 @@ class PjRtCApiBuffer : public PjRtBuffer {
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>

+#include "absl/container/inlined_vector.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/pjrt_client.h"
@@ -320,11 +321,7 @@ class PjRtCApiBuffer : public PjRtBuffer {

bool has_dynamic_dimensions() const override;

Expand All @@ -44,3 +62,13 @@ index b2e2de349..2687b5371 100644

StatusOr<std::vector<int64_t>> logical_dimensions() override;

@@ -407,6 +404,9 @@ class PjRtCApiBuffer : public PjRtBuffer {
std::shared_ptr<PjRtFuture<Status>::Promise> readiness_promise_;
// Set and cached the first time layout() is called.
mutable std::optional<xla::Layout> layout_;
+ // Set and cached the first time is_dynamic_dimension() is called.
+ mutable std::optional<absl::InlinedVector<bool, InlineRank()>>
+ is_dynamic_dimension_;
// Used to synchronize concurrent setting of cached values.
mutable absl::Mutex mu_;
};

0 comments on commit e180883

Please sign in to comment.