Copy input tensors before async transfer #5830
Conversation
Some test seems to fail.
This seems related to my changes in
}
tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false,
                              /*copy=*/true, at::MemoryFormat::Contiguous));
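For reference, a rough Python-level sketch of what this line does (the helper name is illustrative, not part of the PR): it takes a blocking, contiguous copy in the target dtype, so the async transfer reads from a private buffer rather than from the caller's tensor.

import torch

def snapshot_for_transfer(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
    # Blocking, contiguous copy in the target dtype. Because copy=True, the
    # result never aliases `tensor`, so later in-place changes to `tensor`
    # cannot leak into an in-flight device transfer.
    return tensor.to(target_dtype,
                     non_blocking=False,
                     copy=True,
                     memory_format=torch.contiguous_format)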
thinking.. we have 2 options: either copy the tensor on CPU, or only return control to Python after we have started the transfer (or finished the transfer? It's hard for me to tell at which stage the original tensor is no longer needed).
If I understand correctly, instead of creating an xla::Literal, we now perform a copy of the CPU tensor, and copying the CPU tensor is faster?
thinking.. we have 2 options: either copy the tensor on CPU, or only return control to Python after we have started the transfer (or finished the transfer? It's hard for me to tell at which stage the original tensor is no longer needed).
Correct. We don't have any tools to "lock" the input tensor, so we should either block until the transfer is done or make a CPU copy here (which is also blocking, but faster than the CPU -> TPU copy).
The third option is to let the caller decide via the non_blocking argument. Whether the CPU tensor will ever be modified during the transfer is context dependent, so the caller can decide whether an unsafe concurrent copy is okay. We very likely want to set non_blocking=True in our data loader, for example. The default case (non_blocking=False) would skip the copy if the tensor is already contiguous and has the correct dtype, saving host memory. This makes the default case slower, but it also makes it safer (avoiding OOMs and races) and more consistent with upstream/eager, where .to is blocking by default.
If I understand correctly, instead of creating an xla::Literal, we now perform a copy of the CPU tensor, and copying the CPU tensor is faster?
Yeah. I couldn't tell you why this is faster, but it is.
* Copy input tensors before transfer
* clone tensors before passing them to test case
* formatting
* Update test_utils.py
This fixes a subtle edge case introduced with #5772. When we start a transfer asynchronously, this creates a window where changes to the source tensor may be partially reflected in the destination tensor.
Concretely, since our unit test implementation runs each test once on eager CPU and once on XLA, this can double the changes applied in the test (i.e., we make a change on CPU, that change leaks onto the XLA device during the transfer, and then we make the change again in the XLA executable). Our TPU CI is catching this intermittently in TestAtenXlaTensor.test_diagonal_write_transposed_r3. I was able to reproduce this case consistently locally by patching OpenXLA to make transfers much slower, but I don't think this case is unit-testable.
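A minimal sketch of that failure mode (the harness shape and names are illustrative, not the actual test code): the same in-place op runs once on the CPU input and once on its XLA copy, and if the transfer is still in flight when the CPU run mutates the input, the mutation can leak into the XLA copy and then be applied a second time.

import torch

def run_on_both(op, cpu_input: torch.Tensor, xla_device):
    xla_input = cpu_input.to(xla_device)  # transfer may still be in flight...
    op(cpu_input)                         # ...while the eager run mutates the source
    op(xla_input)                         # the XLA run may then apply the change twice
    return cpu_input, xla_input.cpu()

# The test-side fix in this PR is to hand the eager run a clone, e.g.
# op(cpu_input.clone()), so the mutated tensor never shares storage with the
# one being transferred.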
Two fixes:
1. Clone the test input tensor before running the eager version.
2. Copy the input tensor before starting the async transfer. Somehow, copying the tensor within PyTorch is still much faster than the pre-#5772 approach of copying the PyTorch tensor to an xla::Literal: the baseline for TransferToServerTime in one epoch of ResNet50 on v4-8 was 20 seconds, vs. 250 ms now.

Long term, I'd like to eliminate this fix and instead respect the non_blocking argument to .to(). For blocking transfers, we don't need the copy to prevent concurrent changes, which will save host memory. For non-blocking transfers, we can either still perform the copy or simply caution users, since this is not the default case. Since that's a substantial change to behavior either way, I'd rather make it after the next release cut.

Similar issues are possible for non-blocking transfers between GPU and CPU in eager mode. There don't seem to be any tools to lock changes to the source tensor that I can find, nor does there seem to be a general API to tell whether a tensor is ready. The recommended way to synchronize transfers is to check the CUDA device stream.
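The eager-mode analogue mentioned in the last paragraph looks roughly like this (a minimal sketch assuming a CUDA device is available): a non_blocking device-to-host copy is only safe to read after the issuing CUDA stream has been synchronized.

import torch

src = torch.randn(1024, device="cuda")
dst = torch.empty(1024, pin_memory=True)   # pinned host memory enables an async copy
dst.copy_(src, non_blocking=True)          # enqueued on the current CUDA stream

torch.cuda.current_stream().synchronize()  # wait for the copy before reading dst
print(dst.sum())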