Use upstream XLA concurrency utilities #5799
Conversation
LGTM.
test/cpp/test_replication.cpp (outdated diff)

```diff
@@ -57,7 +57,7 @@ void TestSingleReplication(
   std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
       results(device_strings.size());
-  torch_xla::runtime::util::MultiWait mwait(device_strings.size());
+  absl::BlockingCounter mwait(device_strings.size());
```
let's rename it to `bc` or something, `mwait` would be confusing if the type is not `MultiWait`.
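For context, a minimal sketch of the `absl::BlockingCounter` interface with the suggested `bc` name; the surrounding loop and helper names are illustrative, not the actual test code:

```cpp
#include <string>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

// Illustrative only: one decrement per device thread, one blocking wait.
void WaitForAllDevices(const std::vector<std::string>& device_strings) {
  absl::BlockingCounter bc(device_strings.size());  // `bc`, not `mwait`
  std::vector<std::thread> threads;
  for (const std::string& device : device_strings) {
    threads.emplace_back([&bc, &device]() {
      // ... per-device work (e.g. transfer or execution) goes here ...
      bc.DecrementCount();  // atomic decrement; no lock on the common path
    });
  }
  bc.Wait();  // returns once the count reaches zero
  for (std::thread& t : threads) t.join();
}
```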
lgtm, minor nits
* Use TSL threadpool
* remove multiwait
* fix test build
* Move threadpool namespace
* formatting
* fix test build
* Use BlockingCounter
* Distribute Literal->Tensor copies across thread pool
* Update for #5799
@JackCaoG and I have both run into cases where our existing concurrency utilities add significant overhead (e.g. waiting >1ms for a lock when completing a `MultiWait` task), which functionally limits the number of threads we can spawn. This PR replaces two custom implementations of common utilities (thread pool and latch/`MultiWait`) with more optimized upstream equivalents:

* Replace our custom thread pool with `tsl::thread::ThreadPool`. The underlying implementation more carefully reuses threads and handles NUMA affinity to reduce context-switching costs.
  Before: (image)
  After: (image)
* Replace `MultiWait` with `absl::BlockingCounter`. Completing a task requires acquiring a lock, which can cost >1ms in practice with enough threads. `BlockingCounter` instead uses a lockless atomic counter, significantly reducing the latency to decrement the remaining task count. `MultiWait` still exists in upstream PyTorch if we need it for anything; I left usage of the upstream `MultiWait` alone.

We already depend on absl and TSL through OpenXLA, so this PR adds no new dependencies. A brief usage sketch of both utilities is shown below.
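For illustration, a minimal sketch (not the PR's actual code) of the pattern these upstream utilities enable: schedule tasks on a `tsl::thread::ThreadPool` and wait for them with an `absl::BlockingCounter`. The pool name, thread count, and task body are hypothetical, and TSL include paths can differ between versions:

```cpp
#include "absl/synchronization/blocking_counter.h"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

int main() {
  // TSL's pool reuses its worker threads across Schedule() calls.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "example_pool",
                               /*num_threads=*/8);

  constexpr int kNumTasks = 32;
  absl::BlockingCounter counter(kNumTasks);  // plays the role MultiWait used to

  for (int i = 0; i < kNumTasks; ++i) {
    pool.Schedule([&counter] {
      // ... one unit of work (e.g. one replica's execution) goes here ...
      counter.DecrementCount();  // atomic; no lock on the common path
    });
  }
  counter.Wait();  // block until every scheduled task has finished
  return 0;
}
```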
Tested with SPMD llama inference to confirm no regression in performance:
Baseline: (image)
With this PR: (image)
I found these performance benefits while working on #5737, which spawns significantly more threads and potentially reduces the synchronous time spent in `ExecuteReplicated`; see my comment about performance. But I don't expect this PR to provide significant benefits on its own.

Benefits of this PR alone: (image)

Future work:
* Check whether `low_latency_hint` makes a difference in performance. Operations through the thread pool are assumed to be low latency by default (see the sketch after this list).
* Remove the `tf_` prefix from thread names.
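For reference, a hedged sketch of how the pool could be built with the low-latency hint disabled, assuming the TSL `ThreadPool` constructor overload that takes `ThreadOptions` and a `low_latency_hint` flag; the function name, pool name, and include paths here are illustrative and may vary by TSL version:

```cpp
#include <memory>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

// Sketch only: construct the pool with low_latency_hint = false so TSL may
// trade some wakeup latency for less spinning. Whether this helps in practice
// is exactly the open question noted in the future-work item above.
std::unique_ptr<tsl::thread::ThreadPool> MakePool(int num_threads) {
  tsl::ThreadOptions options;  // default stack/guard sizes
  return std::make_unique<tsl::thread::ThreadPool>(
      tsl::Env::Default(), options, "example_pool", num_threads,
      /*low_latency_hint=*/false);
}
```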