Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed May 1, 2024
1 parent 8dcf30f commit 3d7e463
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2482,10 +2482,12 @@ void InitXlaModuleBindings(py::module m) {
return runtime::GetComputationClient()->UnsafeBufferPointer(
UnwrapXlaData(data));
} else {
XLA_ERROR() << "Could not get the buffer pointer.";
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
"with IR that's not DeviceData";
}
}
XLA_ERROR() << "Could not get the buffer pointer.";
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
"without a data handle or an IR.";
});

// -------------Dynamo Integration API Start-------------------------
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(

std::uintptr_t PjRtComputationClient::UnsafeBufferPointer(
const DataPtr handle) {
std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
XLA_CHECK(pjrt_data);
std::shared_ptr<PjRtData> pjrt_data =
std::dynamic_pointer_cast<PjRtData>(handle) XLA_CHECK(pjrt_data)
<< "handle must be PjRtData, got " << handle->ToString();
xla::StatusOr<std::uintptr_t> ptr =
client_->UnsafeBufferPointer(pjrt_data->buffer.get());
XLA_CHECK(ptr.ok());
Expand Down

0 comments on commit 3d7e463

Please sign in to comment.