-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6ff39fa
commit 984b303
Showing
4 changed files
with
57 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc | ||
index 565aa2208..2278ab6c4 100644 | ||
--- a/xla/pjrt/pjrt_c_api_client.cc | ||
+++ b/xla/pjrt/pjrt_c_api_client.cc | ||
@@ -1658,6 +1658,24 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const { | ||
return args.num_dynamic_dims > 0; | ||
} | ||
|
||
+ | ||
+absl::Span<const bool> PjRtCApiBuffer::is_dynamic_dimension() const { | ||
+ PJRT_Buffer_DynamicDimensionIndices_Args args; | ||
+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; | ||
+ args.priv = nullptr; | ||
+ args.buffer = buffer_.get(); | ||
+ | ||
+ pjrt::LogFatalIfPjrtError( | ||
+ pjrt_c_api()->PJRT_Buffer_DynamicDimensionIndices(&args), pjrt_c_api()); | ||
+ | ||
+ absl::InlinedVector<bool, 4> dynamic_dimensions(dimensions().size()); | ||
+ for (int i = 0; i < args.num_dynamic_dims; ++i) { | ||
+ dynamic_dimensions[args.dynamic_dim_indices[i]] = true; | ||
+ } | ||
+ | ||
+ return dynamic_dimensions; | ||
+} | ||
+ | ||
StatusOr<std::vector<int64_t>> PjRtCApiBuffer::logical_dimensions() { | ||
PJRT_Buffer_UnpaddedDimensions_Args args; | ||
args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; | ||
diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h | ||
index b2e2de349..2687b5371 100644 | ||
--- a/xla/pjrt/pjrt_c_api_client.h | ||
+++ b/xla/pjrt/pjrt_c_api_client.h | ||
@@ -379,11 +379,7 @@ class PjRtCApiBuffer : public PjRtBuffer { | ||
|
||
bool has_dynamic_dimensions() const override; | ||
|
||
- absl::Span<const bool> is_dynamic_dimension() const override { | ||
- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. " | ||
- << "Considering using has_dynamic_dimensions() or " | ||
- "logical_dimensions() if applicable."; | ||
- } | ||
+ absl::Span<const bool> is_dynamic_dimension() const override; | ||
|
||
StatusOr<std::vector<int64_t>> logical_dimensions() override; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters