Skip to content

Commit

Permalink
Patch in is_dynamic_dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Sep 7, 2023
1 parent 6ff39fa commit 984b303
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 21 deletions.
1 change: 1 addition & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ http_archive(
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
],
strip_prefix = "xla-cd2cf5c34931e4fc1cacf83bfc480a5b93f05f6d",
urls = [
Expand Down
46 changes: 46 additions & 0 deletions openxla_patches/pjrt_c_api_dynamic_dimensions.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
index 565aa2208..2278ab6c4 100644
--- a/xla/pjrt/pjrt_c_api_client.cc
+++ b/xla/pjrt/pjrt_c_api_client.cc
@@ -1658,6 +1658,24 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const {
return args.num_dynamic_dims > 0;
}

+
+absl::Span<const bool> PjRtCApiBuffer::is_dynamic_dimension() const {
+  PJRT_Buffer_DynamicDimensionIndices_Args args;
+  args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.buffer = buffer_.get();
+  pjrt::LogFatalIfPjrtError(
+      pjrt_c_api()->PJRT_Buffer_DynamicDimensionIndices(&args), pjrt_c_api());
+
+  // thread_local: a Span into a per-call local would dangle after return.
+  static thread_local absl::InlinedVector<bool, 4> dynamic_dimensions;
+  dynamic_dimensions.assign(dimensions().size(), false);
+  for (int i = 0; i < args.num_dynamic_dims; ++i) {
+    dynamic_dimensions[args.dynamic_dim_indices[i]] = true;
+  }
+  return dynamic_dimensions;
+}
+
StatusOr<std::vector<int64_t>> PjRtCApiBuffer::logical_dimensions() {
PJRT_Buffer_UnpaddedDimensions_Args args;
args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE;
diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h
index b2e2de349..2687b5371 100644
--- a/xla/pjrt/pjrt_c_api_client.h
+++ b/xla/pjrt/pjrt_c_api_client.h
@@ -379,11 +379,7 @@ class PjRtCApiBuffer : public PjRtBuffer {

bool has_dynamic_dimensions() const override;

- absl::Span<const bool> is_dynamic_dimension() const override {
- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. "
- << "Considering using has_dynamic_dimensions() or "
- "logical_dimensions() if applicable.";
- }
+ absl::Span<const bool> is_dynamic_dimension() const override;

StatusOr<std::vector<int64_t>> logical_dimensions() override;

25 changes: 4 additions & 21 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,6 @@ std::unordered_map<int, int> build_index_map(
return device_index;
}

// TODO: Do we care about layout here?
// Reconstructs the on-device xla::Shape of `buffer`, including its
// dynamic-dimension mask. A dimension is marked dynamic when its padded
// (physical) extent differs from its logical extent — an approximation,
// since a dynamic dimension whose logical size happens to equal the padded
// size is reported as static.
xla::Shape on_device_shape(xla::PjRtBuffer* buffer) {
auto dimensions = buffer->dimensions();
auto size = dimensions.size();
// TODO: use method when implemented
// Default-initialized to all-static; only filled in when the buffer
// reports any dynamic dimension (avoids the logical_dimensions() call,
// which may block, on the common fully-static path).
std::vector<bool> dynamic_dimensions(size);
if (buffer->has_dynamic_dimensions()) {
auto logical_dimensions = buffer->logical_dimensions().value();
for (int i = 0; i < size; ++i) {
dynamic_dimensions[i] = dimensions[i] != logical_dimensions[i];
}
}

return xla::ShapeUtil::MakeShape(buffer->element_type(), dimensions,
dynamic_dimensions);
}

// Builds the xla::Shape of the output xla::Literal on the host.
xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
xla::Shape shape = xla::ShapeUtil::MakeShape(
Expand Down Expand Up @@ -595,8 +578,8 @@ PjRtComputationClient::ExecuteComputation(
for (auto& result : results) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
device, on_device_shape(buffer.get()), std::move(buffer));
std::shared_ptr<PjRtData> data =
std::make_shared<PjRtData>(device, std::move(buffer));

datas.push_back(data);
}
Expand Down Expand Up @@ -723,8 +706,8 @@ PjRtComputationClient::ExecuteReplicated(
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
devices[i], on_device_shape(buffer.get()), std::move(buffer));
std::shared_ptr<PjRtData> data =
std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ class PjRtComputationClient : public ComputationClient {
std::shared_ptr<xla::PjRtBuffer> buffer)
: Data(std::move(device), std::move(device_shape)), buffer(buffer) {}

// Convenience constructor: derives the device shape directly from the
// buffer (element type, padded dimensions, dynamic-dimension mask, and no
// tuple shapes) so callers no longer have to compute it themselves.
// NOTE(review): relies on buffer->is_dynamic_dimension() returning a span
// that remains valid for the duration of the xla::Shape construction —
// confirm against the PjRtBuffer implementation in use.
PjRtData(std::string device, std::shared_ptr<xla::PjRtBuffer> buffer)
: Data(std::move(device),
xla::Shape(buffer->element_type(), buffer->dimensions(),
buffer->is_dynamic_dimension(), {})),
buffer(buffer) {}

OpaqueHandle GetOpaqueHandle() override {
XLA_CHECK(HasValue())
<< "buffer with shape " << shape().ToString() << " on device "
Expand Down

0 comments on commit 984b303

Please sign in to comment.