Transfer data directly to the device #5772

Merged
will-cromar merged 12 commits into master from wcromar/transfer-to-device-again
Nov 13, 2023

will-cromar
Collaborator

@will-cromar will-cromar commented Nov 6, 2023

Take two. See the notes from the original PR, #5752.

New changes:

  • Use the actual byte strides of the input at::Tensor instead of the strides in the xla::Shape.
  • Downcast the input at::Tensor if the target type differs from the actual type.
  • Release the GIL in XlaDataToTensors, since at::Tensor destruction after TransferToServer can deadlock with TransferFromServer.
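To illustrate the first bullet: `at::Tensor::strides()` reports strides in elements, so byte strides follow from multiplying by the element size. The sketch below uses plain vectors rather than the real `at::Tensor` API, and the helper name `ByteStrides` is hypothetical, not taken from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: converts per-dimension strides measured in elements
// into strides measured in bytes, given the size of one element in bytes.
std::vector<std::int64_t> ByteStrides(
    const std::vector<std::int64_t>& element_strides,
    std::size_t element_size) {
  assert(element_size > 0);
  std::vector<std::int64_t> byte_strides;
  byte_strides.reserve(element_strides.size());
  for (std::int64_t stride : element_strides) {
    byte_strides.push_back(stride * static_cast<std::int64_t>(element_size));
  }
  return byte_strides;
}
```

For a contiguous 2x3 float32 tensor with element strides {3, 1}, this gives {12, 4}. Its transpose has element strides {1, 3}, which gives {4, 12}. Strides derived only from the transposed shape [3, 2], assuming a dense row-major layout, would be {8, 4}, which is the kind of mismatch the first bullet avoids.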

@will-cromar
Collaborator Author

will-cromar commented Nov 6, 2023

The issue here is that I was calculating the byte_strides for an at::Tensor based on the given xla::Shape, which may not match. There's still an unresolved issue here that PjRtClient::BufferFromHostBuffer relies on the target xla::PrimitiveType to determine the size of each element in the source.

In the numerous special cases in GetDevicePrimitiveType where we silently modify types, this means that BufferFromHostBuffer will read the wrong length. For example, it will read the first 32 bits of a float64 on TPU, since we silently downcast float64 to float32. Some of these cases are now irrelevant, but I'm not sure that all of them are. cc @JackCaoG to weigh in.
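A minimal, standalone sketch of the failure mode described above: treating the first 32 bits of a float64 as a float32 is not the same as downcasting. This does not call PJRT; `ReadFirst32Bits` is a hypothetical helper that only mimics reading an element with the wrong size.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical helper: reinterprets the first four bytes of a double as a
// float, which is what happens when a reader assumes 4-byte elements but the
// source buffer actually holds 8-byte elements.
float ReadFirst32Bits(double value) {
  static_assert(sizeof(float) == 4, "sketch assumes 32-bit float");
  static_assert(sizeof(double) == 8, "sketch assumes 64-bit double");
  float truncated;
  std::memcpy(&truncated, &value, sizeof(truncated));
  return truncated;
}
```

Calling `ReadFirst32Bits(1.0)` does not return `1.0f`, while `static_cast<float>(1.0)` does. Which bytes are read depends on the host's byte order, so the exact garbage value is platform-dependent.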

In cases where the source type does not match the output type, I believe we'll still have to "stage" the data in an xla::Literal or cast the at::Tensor. The conversion has to happen somewhere, but I don't think we can do it just by modifying BufferFromHostBuffer's data pointer and byte_strides.

@will-cromar
Collaborator Author

I have everything working locally now. Separating the more tedious changes here into #5777.

After this PR, we'll still have to make an intermediate copy if the input tensor type does not match the target type. I added a counter to capture this overhead, since it may have a performance impact. For what it's worth, casting the at::Tensor directly still seems to be faster than copying to an xla::Literal.
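The copy-only-on-mismatch behavior and its counter can be sketched as follows. This is an illustration with plain vectors, not the PR's implementation; `TransferStats` and `AsFloat32` are hypothetical names, and float32 stands in for whatever the device type is.

```cpp
#include <cassert>
#include <vector>

// Hypothetical counter for intermediate copies caused by type mismatches.
struct TransferStats {
  int cast_copies = 0;
};

// Source already matches the device type: hand it through without copying.
const std::vector<float>& AsFloat32(const std::vector<float>& source,
                                    std::vector<float>& /*staging*/,
                                    TransferStats& /*stats*/) {
  return source;
}

// Source is wider than the device type: cast into a staging buffer and
// record that an extra copy was made.
const std::vector<float>& AsFloat32(const std::vector<double>& source,
                                    std::vector<float>& staging,
                                    TransferStats& stats) {
  staging.assign(source.begin(), source.end());
  ++stats.cast_copies;
  return staging;
}
```

With a float32 source the returned reference is the source itself and the counter stays at zero; with a float64 source the counter increments once per call.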

Getting around the copy is simple: just create the tensors such that the type matches what it will be on the device. So if you want a bf16 on device, make a bf16 tensor.

@will-cromar will-cromar force-pushed the wcromar/transfer-to-device-again branch from 13035ed to df239db Compare November 8, 2023 17:41
@will-cromar
Collaborator Author

The TPU CI is currently hanging on SPMDLoadPlannerTest.test_resolve_and_commit_sharded_tensor. I can reproduce this locally, but not if I step through the lines with pdb. That suggests it may be hanging due to a race condition. This test does not look particularly related to this PR.

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
@will-cromar will-cromar force-pushed the wcromar/transfer-to-device-again branch from 30abda8 to 62cf72d Compare November 9, 2023 01:46
@will-cromar will-cromar marked this pull request as ready for review November 9, 2023 17:00
@will-cromar will-cromar added DO_NOT_MERGE_YET (for PRs which cannot be merged, despite tests passing) and runtime labels Nov 9, 2023
@will-cromar
Collaborator Author

@jonb377 and I were able to figure out where the deadlock is. The hang is caused by a GIL deadlock when we try to retrieve data from the device before a transfer finishes, and the transfer is the only thing keeping an at::Tensor alive. If anything, I'm surprised only one test case caught this bug. Marking this PR as "do not merge" until we find a fix.

Here's what's happening:

  1. The main thread moves some data to the TPU. This immediately kicks off an async thread to handle the transfer (which is kind of the point of the whole PR).
  2. The main thread continues to call local_shards, which calls down to XlaDataToTensors, which calls TransferFromServer, which has to block until the transfer from (1) is complete. Note that the main thread is holding the GIL.
  3. The async transfer thread from (1) kicks off, holding a callback that holds the only shared_ptr to an AtenSource, which holds an at::Tensor. This is required to ensure that the at::Tensor is kept alive until the transfer is complete.
  4. The actual transfer of data from host to device is complete, firing the callback that holds the only shared_ptr to the AtenSource.
  5. The callback functionally does nothing except delete the shared_ptr<AtenSource>. Because it is the last owner of the underlying at::Tensor, this destroys the at::Tensor, which destroys the c10::TensorImpl.
  6. c10::TensorImpl's destructor waits to acquire the GIL.
  7. Since TransferToServer fires the callback synchronously, it cannot return until the GIL is acquired, even though the actual transfer to the TPU is complete.
  8. See 2.
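Steps 3 to 5 hinge on one property: when the transfer thread holds the last `shared_ptr`, the destructor runs on that thread, not on the caller's. A standalone sketch of that property, with a hypothetical `TrackedSource` standing in for `AtenSource` and a plain `std::thread` standing in for the PJRT transfer thread:

```cpp
#include <cassert>
#include <memory>
#include <thread>
#include <utility>

// Hypothetical stand-in for AtenSource: records which thread destroyed it.
struct TrackedSource {
  explicit TrackedSource(std::thread::id* out) : destroyed_on(out) {
    assert(out != nullptr);
  }
  ~TrackedSource() { *destroyed_on = std::this_thread::get_id(); }
  std::thread::id* destroyed_on;
};

// Hands the only owning pointer to a worker thread and reports the thread on
// which the destructor ran.
std::thread::id DestroyedOn() {
  std::thread::id destroyed_on;
  auto source = std::make_shared<TrackedSource>(&destroyed_on);
  // After the move, the worker's closure is the sole owner of the source.
  std::thread worker([held = std::move(source)]() mutable {
    held.reset();  // last reference dropped here, so the destructor runs here
  });
  worker.join();
  return destroyed_on;
}
```

The returned id differs from the calling thread's id. In the scenario above, that destructor additionally needs the GIL, which the calling thread is holding while it waits.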

Here's the relevant stack trace through TransferToServer from gdb:

...
#4  take_gil (ceval=0x7f210c228048 <_PyRuntime+584>, tstate=0x7f1a40002bb0) at Python/ceval_gil.h:206
#5  0x00007f210c0c3655 in PyEval_AcquireThread (tstate=0x7f1a40002bb0) at Python/ceval.c:316
#6  0x00007f2107849f05 in pybind11::gil_scoped_acquire::gil_scoped_acquire() () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_python.so
...
#8  0x00007f21073f932b in c10::impl::PyObjectSlot::maybe_destroy_pyobj() () from /usr/local/lib/python3.8/site-packages/torch/lib/libc10.so
...
#11 0x00007f21073e8669 in c10::TensorImpl::~TensorImpl() () from /usr/local/lib/python3.8/site-packages/torch/lib/libc10.so
#12 0x00007f1e55ad83b3 in std::_Sp_counted_ptr_inplace<torch_xla::runtime::AtenSource, std::allocator<torch_xla::runtime::AtenSource>, (__gnu_cxx::_Lock_policy)2>::_M_dispose() () from /usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git30abda8-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so
#13 0x00007f1e55935b3a in std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() () from /usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git30abda8-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so
#14 0x00007f1e55d9d3c4 in std::_Function_handler<void (), torch_xla::runtime::PjRtComputationClient::TransferToServer(absl::lts_20230125::Span<std::shared_ptr<torch_xla::runtime::TensorSource const> const>)::{lambda()#1}>::_M_manager(std::_Any_data&, std::_Any_data const&, std::_Manager_operation) () from /usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git30abda8-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so
#15 0x00007f1e5a81ef92 in std::_Function_handler<void (PJRT_Error*), xla::PjRtCApiClient::BufferFromHostBufferInternalImpl(void const*, xla::PrimitiveType, absl::lts_20230125::Span<long const>, std::optional<absl::lts_20230125::Span<long const> >, xla::PjRtClient::HostBufferSemantics, std::function<void ()>, std::variant<xla::PjRtDevice*, xla::PjRtMemorySpace*>, xla::Layout const*)::{lambda(PJRT_Error*)#1}>::_M_manager(std::_Any_data&, std::_Any_data const&, std::_Manager_operation) () from /usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git30abda8-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so

@will-cromar
Collaborator Author

I fixed a similar GIL deadlock bug about a year ago in #4504. In that case, the solution was to release the GIL during TransferFromServer, but it looks like I inserted that fix into a different code path than SPMD uses.

@will-cromar
Collaborator Author

will-cromar commented Nov 9, 2023

I wrapped the GIL release and data transfer into a new utility, ReleaseGilAndTransferData. GetTensors and XlaDataToTensors can both call this function. I also added a note to computation_client.h to be wary of the GIL when calling TransferFromServer, since it's the only call to PJRT that is blocking.
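The shape of this fix can be sketched without Python: a mutex stands in for the GIL, and a scoped guard releases it for the duration of the blocking wait. This is an assumption-laden model, not the code in this PR; `fake_gil`, `ScopedRelease`, `ReleaseAndWait`, and `Demo` are all hypothetical names.

```cpp
#include <cassert>
#include <future>
#include <mutex>
#include <thread>

std::mutex fake_gil;  // hypothetical stand-in for the Python GIL

// Releases the lock on construction and re-acquires it on destruction.
// The constructing thread must currently hold the lock.
class ScopedRelease {
 public:
  explicit ScopedRelease(std::mutex& m) : m_(m) { m_.unlock(); }
  ~ScopedRelease() { m_.lock(); }
  ScopedRelease(const ScopedRelease&) = delete;
  ScopedRelease& operator=(const ScopedRelease&) = delete;

 private:
  std::mutex& m_;
};

// Models the blocking device-to-host transfer: drop the lock, then wait.
int ReleaseAndWait(std::future<int>& pending) {
  ScopedRelease release(fake_gil);
  return pending.get();
}

int Demo() {
  std::lock_guard<std::mutex> hold(fake_gil);  // caller "holds the GIL"
  std::promise<int> done;
  std::future<int> pending = done.get_future();
  std::thread worker([&done]() {
    // Models the destructor that needs the GIL before the transfer can finish.
    std::lock_guard<std::mutex> need(fake_gil);
    done.set_value(42);
  });
  int result = ReleaseAndWait(pending);
  worker.join();
  return result;
}
```

If `ReleaseAndWait` waited on `pending` without the `ScopedRelease`, the worker could never take `fake_gil` and the wait would never return, which mirrors the hang described earlier in this thread.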

I'm open to better names for ReleaseGilAndTransferData.

@will-cromar will-cromar requested a review from yeounoh November 9, 2023 21:56
@will-cromar will-cromar removed the DO_NOT_MERGE_YET (for PRs which cannot be merged, despite tests passing) label Nov 10, 2023
@@ -37,17 +37,17 @@ class PjRtComputationClient : public ComputationClient {
   std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

   std::vector<DataPtr> TransferToServer(
-      absl::Span<const TensorSource> tensors) override;
+      absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;
Contributor

Should we require the caller to manage the ownership and lifetime of the TensorSource, i.e., take a const TensorSource* instead? Or is it necessary to ensure that the memory is held by the client for the duration of the client ops?

Collaborator Author

Good question. PJRT lets us tie the lifetime of an object to an operation by capturing it in a callback. std::functions have to be copyable, so shared_ptr is our best choice here. TensorSource itself may be expensive or impossible to copy.

The caller of TransferToServer will be much shorter-lived than the actual transfer, so ownership should pass down. We could tighten up the interface here and consume a unique_ptr since we only need copyability within the implementation of TransferToServer. What do you think?
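A sketch of the tighter interface floated here: accept a `unique_ptr` at the boundary and convert it to a `shared_ptr` internally, since a lambda that captures a `unique_ptr` is move-only and cannot be stored in a `std::function`. `FakeSource` and `MakeOnDone` are hypothetical names; the real callback would return nothing, and this one returns a size only so the example can show the data is still alive.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a TensorSource implementation.
struct FakeSource {
  explicit FakeSource(std::vector<float> values) : data(std::move(values)) {}
  std::vector<float> data;
};

// Takes sole ownership from the caller, then shares it with a copyable
// callback that keeps the source alive for as long as any copy exists.
std::function<std::size_t()> MakeOnDone(std::unique_ptr<FakeSource> source) {
  assert(source != nullptr);
  std::shared_ptr<FakeSource> shared = std::move(source);
  return [shared]() { return shared->data.size(); };
}
```

Copies of the returned `std::function` share ownership, so the source is destroyed only when the last copy goes away, wherever that happens to be.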

Contributor

@yeounoh yeounoh left a comment

LGTM

@JackCaoG
Collaborator

I will try to take a look today

@will-cromar will-cromar merged commit 05a3cdd into master Nov 13, 2023
18 checks passed
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
* Transfer data directly to the device (pytorch#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
zpcore pushed a commit that referenced this pull request Nov 21, 2023
lsy323 pushed a commit to lsy323/xla that referenced this pull request Nov 28, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024