Skip to content

Commit

Permalink
Remove on_device_shape and logical_on_device_shape (#5546)
Browse files Browse the repository at this point in the history
* [WIP] Remove `on_device_shape` and `logical_on_device_shape`

* logical_on_device_shape -> host_output_shape

* Patch in `is_dynamic_dimension`

* Update patch with real pending change

* Update patch

* add commit hash
  • Loading branch information
will-cromar authored Sep 19, 2023
1 parent a4874e2 commit dbb92c9
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 7 deletions.
1 change: 1 addition & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ http_archive(
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
],
strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18",
urls = [
Expand Down
76 changes: 76 additions & 0 deletions openxla_patches/pjrt_c_api_dynamic_dimensions.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Partial backport of 6308dba2903e78961ac4122f361bc91b09f36891. Remove in next
# pin update.
diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
index ef0b6686c..c0341e81e 100644
--- a/xla/pjrt/pjrt_c_api_client.cc
+++ b/xla/pjrt/pjrt_c_api_client.cc
@@ -1584,6 +1584,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const {
return args.num_dynamic_dims > 0;
}

+absl::Span<const bool> PjRtCApiBuffer::is_dynamic_dimension() const {  // Returns one flag per dimension; computed on first call and cached in is_dynamic_dimension_.
+ {
+ absl::MutexLock lock(&mu_);  // mu_ is held while the cached value is checked and filled in.
+ if (!is_dynamic_dimension_.has_value()) {
+ absl::InlinedVector<bool, InlineRank()>& is_dynamic_dimension_value =
+ is_dynamic_dimension_.emplace();
+ is_dynamic_dimension_value.assign(dimensions().size(), false);  // Start with every dimension marked as not dynamic.
+
+ PJRT_Buffer_DynamicDimensionIndices_Args args;
+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
+ args.priv = nullptr;
+ args.buffer = buffer_.get();
+ const PJRT_Api* api = pjrt_c_api();
+ std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> error(
+ api->PJRT_Buffer_DynamicDimensionIndices(&args),  // Query the C API for the indices of the dynamic dimensions.
+ pjrt::MakeErrorDeleter(api));
+ if (error && pjrt::GetErrorCode(error.get(), api) ==
+ PJRT_Error_Code_UNIMPLEMENTED) {  // An UNIMPLEMENTED error keeps the all-false default.
+ return *is_dynamic_dimension_;
+ }
+ for (int i = 0; i < args.num_dynamic_dims; ++i) {  // NOTE(review): errors other than UNIMPLEMENTED also reach this loop and read args - confirm that is intended.
+ is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true;  // Flag each index reported by the C API.
+ }
+ }
+ }
+ return *is_dynamic_dimension_;  // The returned span refers to the cached vector.
+}
+
StatusOr<std::vector<int64_t>> PjRtCApiBuffer::logical_dimensions() {
PJRT_Buffer_UnpaddedDimensions_Args args;
args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE;
diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h
index 9c460f246..279608e60 100644
--- a/xla/pjrt/pjrt_c_api_client.h
+++ b/xla/pjrt/pjrt_c_api_client.h
@@ -27,6 +27,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
@@ -369,11 +370,7 @@ class PjRtCApiBuffer : public PjRtBuffer {

bool has_dynamic_dimensions() const override;

- absl::Span<const bool> is_dynamic_dimension() const override {
- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. "
- << "Considering using has_dynamic_dimensions() or "
- "logical_dimensions() if applicable.";
- }
+ absl::Span<const bool> is_dynamic_dimension() const override;

StatusOr<std::vector<int64_t>> logical_dimensions() override;

@@ -455,6 +452,9 @@ class PjRtCApiBuffer : public PjRtBuffer {
std::shared_ptr<PjRtFuture<Status>::Promise> readiness_promise_;
// Set and cached the first time layout() is called.
mutable std::optional<xla::Layout> layout_;
+ // Set and cached the first time is_dynamic_dimension() is called.
+ mutable std::optional<absl::InlinedVector<bool, InlineRank()>>
+ is_dynamic_dimension_;
// Used to synchronize concurrent setting of cached values.
mutable absl::Mutex mu_;
};
22 changes: 15 additions & 7 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ std::unordered_map<int, int> build_index_map(
return device_index;
}

// Returns the xla::Shape to use for the host-side xla::Literal that receives
// `buffer`'s contents. The shape is assembled from the buffer's element type,
// its logical dimensions and its layout, and is then passed through
// xla::ShapeUtil::DeviceShapeToHostShape.
xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
  // `.value()` is called directly on the returned StatusOr, as before.
  const auto logical_dims = buffer->logical_dimensions().value();

  xla::Shape device_side_shape =
      xla::ShapeUtil::MakeShape(buffer->element_type(), logical_dims);
  *device_side_shape.mutable_layout() = buffer->layout();

  return xla::ShapeUtil::DeviceShapeToHostShape(device_side_shape);
}

} // namespace

std::string PjRtComputationClient::PjRtDeviceToString(
Expand Down Expand Up @@ -424,9 +433,8 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
auto new_handle = ReplicateShardedData(handle);
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);

xla::Shape target_shape = xla::ShapeUtil::DeviceShapeToHostShape(
pjrt_data.buffer->logical_on_device_shape().value());
auto& literal = literals.emplace_back(target_shape);
auto& literal =
literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal));

total_size += literal.size_bytes();
Expand Down Expand Up @@ -569,8 +577,8 @@ PjRtComputationClient::ExecuteComputation(
for (auto& result : results) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
device, buffer->on_device_shape(), std::move(buffer));
std::shared_ptr<PjRtData> data =
std::make_shared<PjRtData>(device, std::move(buffer));

datas.push_back(data);
}
Expand Down Expand Up @@ -697,8 +705,8 @@ PjRtComputationClient::ExecuteReplicated(
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
devices[i], buffer->on_device_shape(), std::move(buffer));
std::shared_ptr<PjRtData> data =
std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ class PjRtComputationClient : public ComputationClient {
std::shared_ptr<xla::PjRtBuffer> buffer)
: Data(std::move(device), std::move(device_shape)), buffer(buffer) {}

// Wraps `buffer` as device data on `device`. The xla::Shape handed to the
// Data base class is derived from the buffer itself: its element type, its
// dimensions and its per-dimension dynamic flags.
PjRtData(std::string device, std::shared_ptr<xla::PjRtBuffer> buffer)
    : Data(std::move(device),
           xla::Shape(buffer->element_type(), buffer->dimensions(),
                      buffer->is_dynamic_dimension(), {})),
      // The base class above is always initialized before this member, so the
      // parameter is still intact when it is dereferenced there. Moving it
      // here avoids an extra shared_ptr reference-count increment/decrement.
      buffer(std::move(buffer)) {}

OpaqueHandle GetOpaqueHandle() override {
XLA_CHECK(HasValue())
<< "buffer with shape " << shape().ToString() << " on device "
Expand Down

0 comments on commit dbb92c9

Please sign in to comment.