Skip to content

Commit

Permalink
logical_on_device_shape -> host_output_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Sep 7, 2023
1 parent 07ac0a2 commit 6ff39fa
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,30 @@ std::unordered_map<int, int> build_index_map(
return device_index;
}

// TODO: Do we care about layout here?
xla::Shape on_device_shape(xla::PjRtBuffer* buffer) {
auto dimensions = buffer->dimensions();
auto size = dimensions.size();
absl::InlinedVector<bool, 4> dynamic_dimensions(size);
// TODO: use method when implemented
std::vector<bool> dynamic_dimensions(size);
if (buffer->has_dynamic_dimensions()) {
// TODO: use method when implemented
auto logical_dimensions = buffer->logical_dimensions().value();
for (int i = 0; i < size; ++i) {
dynamic_dimensions[i] = dimensions[i] != logical_dimensions[i];
}
}

return xla::Shape(buffer->element_type(), dimensions, dynamic_dimensions, {});
return xla::ShapeUtil::MakeShape(buffer->element_type(), dimensions,
dynamic_dimensions);
}

xla::Shape logical_on_device_shape(xla::PjRtBuffer* buffer) {
XLA_CHECK(!buffer->has_dynamic_dimensions());
auto dimensions = buffer->logical_dimensions().value();
absl::InlinedVector<bool, 4> dynamic_dimensions(dimensions.size());
// Builds the xla::Shape of the output xla::Literal on the host.
xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
xla::Shape shape = xla::ShapeUtil::MakeShape(
buffer->element_type(), buffer->logical_dimensions().value());
*shape.mutable_layout() = buffer->layout();

return xla::Shape(buffer->element_type(), dimensions, dynamic_dimensions, {});
return xla::ShapeUtil::DeviceShapeToHostShape(shape);
}

} // namespace
Expand Down Expand Up @@ -448,9 +451,8 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
auto new_handle = ReplicateShardedData(handle);
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);

auto shape = logical_on_device_shape(pjrt_data.buffer.get());
*shape.mutable_layout() = pjrt_data.buffer->layout();
auto& literal = literals.emplace_back(xla::ShapeUtil::DeviceShapeToHostShape(shape));
auto& literal =
literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal));

total_size += literal.size_bytes();
Expand Down

0 comments on commit 6ff39fa

Please sign in to comment.