Remove on_device_shape and logical_on_device_shape #5546

Merged · 6 commits · Sep 19, 2023
1 change: 1 addition & 0 deletions WORKSPACE
@@ -42,6 +42,7 @@ http_archive(
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
],
strip_prefix = "xla-7a371ed44aba34f83d6d3d1159d2e6d0d327c603",
urls = [
76 changes: 76 additions & 0 deletions openxla_patches/pjrt_c_api_dynamic_dimensions.diff
@@ -0,0 +1,76 @@
# Partial backport of 6308dba2903e78961ac4122f361bc91b09f36891. Remove in next
# pin update.
diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
index ef0b6686c..c0341e81e 100644
--- a/xla/pjrt/pjrt_c_api_client.cc
+++ b/xla/pjrt/pjrt_c_api_client.cc
@@ -1584,6 +1584,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const {
return args.num_dynamic_dims > 0;
}

+absl::Span<const bool> PjRtCApiBuffer::is_dynamic_dimension() const {
+ {
+ absl::MutexLock lock(&mu_);
+ if (!is_dynamic_dimension_.has_value()) {
+ absl::InlinedVector<bool, InlineRank()>& is_dynamic_dimension_value =
+ is_dynamic_dimension_.emplace();
+ is_dynamic_dimension_value.assign(dimensions().size(), false);
+
+ PJRT_Buffer_DynamicDimensionIndices_Args args;
+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
+ args.priv = nullptr;
+ args.buffer = buffer_.get();
+ const PJRT_Api* api = pjrt_c_api();
+ std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> error(
+ api->PJRT_Buffer_DynamicDimensionIndices(&args),
+ pjrt::MakeErrorDeleter(api));
+ if (error && pjrt::GetErrorCode(error.get(), api) ==
+ PJRT_Error_Code_UNIMPLEMENTED) {
+ return *is_dynamic_dimension_;
+ }
+ for (int i = 0; i < args.num_dynamic_dims; ++i) {
+ is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true;
+ }
+ }
+ }
+ return *is_dynamic_dimension_;
+}
+
StatusOr<std::vector<int64_t>> PjRtCApiBuffer::logical_dimensions() {
PJRT_Buffer_UnpaddedDimensions_Args args;
args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE;
diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h
index 9c460f246..279608e60 100644
--- a/xla/pjrt/pjrt_c_api_client.h
+++ b/xla/pjrt/pjrt_c_api_client.h
@@ -27,6 +27,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
@@ -369,11 +370,7 @@ class PjRtCApiBuffer : public PjRtBuffer {

bool has_dynamic_dimensions() const override;

- absl::Span<const bool> is_dynamic_dimension() const override {
- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. "
- << "Considering using has_dynamic_dimensions() or "
- "logical_dimensions() if applicable.";
- }
+ absl::Span<const bool> is_dynamic_dimension() const override;

StatusOr<std::vector<int64_t>> logical_dimensions() override;

@@ -455,6 +452,9 @@ class PjRtCApiBuffer : public PjRtBuffer {
std::shared_ptr<PjRtFuture<Status>::Promise> readiness_promise_;
// Set and cached the first time layout() is called.
mutable std::optional<xla::Layout> layout_;
+ // Set and cached the first time is_dynamic_dimension() is called.
+ mutable std::optional<absl::InlinedVector<bool, InlineRank()>>
+ is_dynamic_dimension_;
// Used to synchronize concurrent setting of cached values.
mutable absl::Mutex mu_;
};
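
A note on the backported implementation above: is_dynamic_dimension() computes the per-dimension flags at most once, guarded by mu_, and caches them in the new std::optional member, so repeated calls return the cached span without another C API round trip. Below is a minimal standalone C++ sketch of the same lock-then-cache pattern; Buffer, rank_, and QueryDynamicIndices() are stand-ins invented for illustration, not PJRT API.

#include <mutex>
#include <optional>
#include <vector>

class Buffer {
 public:
  // Returns per-dimension dynamic flags, computing them at most once.
  const std::vector<bool>& is_dynamic_dimension() const {
    std::lock_guard<std::mutex> lock(mu_);
    if (!flags_.has_value()) {
      // Default every dimension to static, then flip the dynamic ones.
      std::vector<bool>& flags = flags_.emplace(rank_, false);
      for (int index : QueryDynamicIndices()) {  // stand-in for the C API call
        flags[index] = true;
      }
    }
    return *flags_;
  }

 private:
  // Placeholder for PJRT_Buffer_DynamicDimensionIndices.
  std::vector<int> QueryDynamicIndices() const { return {1}; }

  int rank_ = 2;
  mutable std::mutex mu_;                           // guards flags_
  mutable std::optional<std::vector<bool>> flags_;  // lazily filled cache
};

int main() {
  Buffer buffer;
  return buffer.is_dynamic_dimension()[1] ? 0 : 1;
}
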
22 changes: 15 additions & 7 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -70,6 +70,15 @@ std::unordered_map<int, int> build_index_map(
return device_index;
}

+// Builds the xla::Shape of the output xla::Literal on the host.
+xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
+  xla::Shape shape = xla::ShapeUtil::MakeShape(
+      buffer->element_type(), buffer->logical_dimensions().value());
+  *shape.mutable_layout() = buffer->layout();
+
+  return xla::ShapeUtil::DeviceShapeToHostShape(shape);
+}

} // namespace

std::string PjRtComputationClient::PjRtDeviceToString(
@@ -424,9 +433,8 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
auto new_handle = ReplicateShardedData(handle);
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);

-  xla::Shape target_shape = xla::ShapeUtil::DeviceShapeToHostShape(
-      pjrt_data.buffer->logical_on_device_shape().value());
-  auto& literal = literals.emplace_back(target_shape);
+  auto& literal =
+      literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal));

total_size += literal.size_bytes();
@@ -569,8 +577,8 @@ PjRtComputationClient::ExecuteComputation(
for (auto& result : results) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);

-    std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
-        device, buffer->on_device_shape(), std::move(buffer));
+    std::shared_ptr<PjRtData> data =
+        std::make_shared<PjRtData>(device, std::move(buffer));

datas.push_back(data);
}
@@ -697,8 +705,8 @@ PjRtComputationClient::ExecuteReplicated(
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

-      std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
-          devices[i], buffer->on_device_shape(), std::move(buffer));
+      std::shared_ptr<PjRtData> data =
+          std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
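
The replaced call sites above show the shape of the change: instead of reconstructing the full on-device shape via logical_on_device_shape() or on_device_shape(), the host shape is assembled from three per-field buffer queries (element type, logical dimensions, layout) and then converted with DeviceShapeToHostShape. A sketch of the resulting read-back path, mirroring host_output_shape() and TransferFromServer above; ReadBack is a hypothetical wrapper, and the include paths assume the OpenXLA tree at this pin.

#include "absl/log/check.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/shape_util.h"

xla::Literal ReadBack(xla::PjRtBuffer* buffer) {
  // Compose the logical (unpadded) shape from per-field queries rather
  // than logical_on_device_shape().
  xla::Shape shape = xla::ShapeUtil::MakeShape(
      buffer->element_type(), buffer->logical_dimensions().value());
  *shape.mutable_layout() = buffer->layout();

  // Drop device-specific layout details for the host-side copy.
  xla::Literal literal(xla::ShapeUtil::DeviceShapeToHostShape(shape));
  CHECK_OK(buffer->ToLiteralSync(&literal));  // device -> host copy
  return literal;
}
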
6 changes: 6 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -164,6 +164,12 @@ class PjRtComputationClient : public ComputationClient {
std::shared_ptr<xla::PjRtBuffer> buffer)
: Data(std::move(device), std::move(device_shape)), buffer(buffer) {}

+    PjRtData(std::string device, std::shared_ptr<xla::PjRtBuffer> buffer)
+        : Data(std::move(device),
+               xla::Shape(buffer->element_type(), buffer->dimensions(),
+                          buffer->is_dynamic_dimension(), {})),
+          buffer(buffer) {}

OpaqueHandle GetOpaqueHandle() override {
XLA_CHECK(HasValue())
<< "buffer with shape " << shape().ToString() << " on device "
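
For reference, the xla::Shape constructor used by the new two-argument PjRtData above takes the element type, the dimensions, the per-dimension dynamic flags, and the tuple subshapes, with the trailing {} meaning "not a tuple". A minimal sketch, assuming the XLA headers at this pin; the F32[2,<=3] example values are invented for illustration.

#include "xla/shape.h"

int main() {
  // Same form the new constructor builds from the buffer's queries:
  // element type, dimensions, dynamic flags, empty tuple_shapes.
  xla::Shape shape(xla::F32, /*dimensions=*/{2, 3},
                   /*dynamic_dimensions=*/{false, true},
                   /*tuple_shapes=*/{});
  return shape.is_dynamic_dimension(1) ? 0 : 1;  // dimension 1 is dynamic
}
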