Distribute Literal->Tensor copies across thread pool #5825
Conversation
Force-pushed from b9d8a7c to 8992668
@@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(

 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
What's the reason for this change?
It seems like we never really call XlaDataToTensors with different dest_element_type values. @jonb377, are you introducing a new use case? If not, can we keep it as a singleton?
The reason to make dest_element_type a vector is actually the next change: I'm batching the local shard transfers for many tensors into a single XlaDataToTensors call. I probably should have kept this refactor with the upcoming change, but keeping it here makes that PR slightly smaller.
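For illustration, a hypothetical call site under the new span-based signature (shard_handles, shard_dtypes, and CollectShardHandles are illustrative names, not identifiers from this PR): the caller passes one destination dtype per data handle, so many shard transfers collapse into a single XlaDataToTensors call.

// Hypothetical call site; ATen/XLA includes omitted. Only XlaDataToTensors
// is from this PR -- the other names are placeholders.
std::vector<torch::lazy::BackendDataPtr> shard_handles = CollectShardHandles();
// One destination dtype per handle; here every shard happens to be float.
std::vector<at::ScalarType> shard_dtypes(shard_handles.size(),
                                         at::ScalarType::Float);
// dest_element_type[i] gives the destination type for xla_data[i], so a
// single call now covers every shard.
std::vector<at::Tensor> shards = XlaDataToTensors(shard_handles, shard_dtypes);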
 std::vector<at::Tensor> tensors(literals.size());
 absl::BlockingCounter counter(literals.size());
 for (size_t i = 0; i < tensors.size(); ++i) {
   auto copy_fn = [&, i]() {
Can you capture the variables you need explicitly?
Actually, you need just about every variable in this scope since it's pretty narrow. I take that back.
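For reference, a minimal sketch of the pattern under discussion, assuming a generic Schedule(closure) thread-pool helper and a LiteralToTensor conversion helper (both are stand-ins, not the actual torch_xla APIs): each Literal->Tensor copy becomes an independent task, and an absl::BlockingCounter blocks the caller until every copy has finished.

#include <utility>
#include <vector>

#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
// ATen and XLA headers omitted for brevity.

// Sketch only: Schedule() and LiteralToTensor() are placeholders for the
// real thread-pool and literal-conversion helpers.
std::vector<at::Tensor> CopyLiteralsInParallel(
    std::vector<xla::Literal>& literals,
    absl::Span<const at::ScalarType> dest_element_type) {
  std::vector<at::Tensor> tensors(literals.size());
  absl::BlockingCounter counter(literals.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    auto copy_fn = [&, i]() {
      // Each task only touches literals[i] and tensors[i], so the copies are
      // independent and safe to run concurrently.
      tensors[i] = LiteralToTensor(literals[i], dest_element_type[i]);
      counter.DecrementCount();
    };
    Schedule(std::move(copy_fn));  // hand the copy off to the thread pool
  }
  counter.Wait();  // block until every copy has completed
  return tensors;
}

Using a blocking counter keeps the function's synchronous contract: callers still receive fully materialized tensors, only the per-tensor copies run concurrently.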
@@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(

 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
   std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
I wonder if we should just be returning Tensors here.
I'm interested in making TransferFromServer return at::Tensor and cutting out the xla::Literal middleman, but that's still in the idea phase. For now I opted to keep this change smaller and just distribute the copy work over more cores.
Thanks for the reviews @will-cromar and @JackCaoG!
LGTM
Force-pushed from 8992668 to fa65980
Force-pushed from fa65980 to c8f7315
* Distribute Literal->Tensor copies across thread pool
* Update for #5799
After an xla::Literal has been created in TransferFromServer, it must be copied into an at::Tensor. This copy incurs significant overhead (up to 3x the transfer overhead after #5824), because the copies still occur synchronously on a single thread.
This change dispatches the copies to a thread pool. When checkpointing a 2B-parameter model, the copy overhead decreases from ~5000ms to ~611ms.*
*Note: these benchmarks were taken prior to #5799 and used the old threading library.