
Distribute Literal->Tensor copies across thread pool #5825

Merged
jonb377 merged 2 commits into master from jonbolin/copy-pool on Dec 1, 2023

Conversation

jonb377 (Collaborator) commented Nov 20, 2023

After an xla::Literal has been created in TransferFromServer, it must be copied into an at::Tensor. This incurs significant overhead (up to 3x the transfer overhead after #5824), because the copies still occur synchronously on a single thread.

This change dispatches the copies to a thread pool to speed up the process. When checkpointing a 2B parameter model, the overhead decreases from ~5000ms to ~611ms.*

*Note: These benchmarks were prior to #5799 and used the old threading library.
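The core pattern: each Literal->Tensor copy is independent, so every copy is enqueued as its own closure on the thread pool and the calling thread blocks on an absl::BlockingCounter until all copies have landed. A minimal sketch of that idea (not the exact PR code; Schedule() stands in for the project's thread-pool scheduling helper and MakeTensorFromXlaLiteral() for the per-literal copy helper):

// Sketch only. Assumes the ATen/XLA headers are available, plus:
//   #include "absl/synchronization/blocking_counter.h"
//   #include "absl/types/span.h"
// Schedule() and MakeTensorFromXlaLiteral() are stand-ins, not verified signatures.
std::vector<at::Tensor> CopyLiteralsInParallel(
    const std::vector<xla::Literal>& literals,
    absl::Span<const at::ScalarType> dest_element_type) {
  std::vector<at::Tensor> tensors(literals.size());
  absl::BlockingCounter counter(literals.size());
  for (size_t i = 0; i < literals.size(); ++i) {
    // Each closure writes only its own slot in `tensors`, so the copies can
    // run concurrently without extra synchronization.
    Schedule([&, i]() {
      tensors[i] = MakeTensorFromXlaLiteral(literals[i], dest_element_type[i]);
      counter.DecrementCount();
    });
  }
  // Wait for every scheduled copy before returning the tensors.
  counter.Wait();
  return tensors;
}

The blocking counter preserves the function's synchronous contract: callers still get fully materialized tensors back, only the copy work itself fans out across cores.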

@@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(

 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
Collaborator

What's the reason for this change?

Collaborator

It seems like we never really call XlaDataToTensors with different dest_element_type. @jonb377, are you introducing a new use case? If not, we can keep it as a singleton?

Collaborator Author

The reason to make the dest_elem_types a vector is actually the next change - I'm batching the local shard transfers for many tensors into a single XlaDataToTensors call. I probably should have kept this refactor with the upcoming change... But it makes that PR slightly smaller.
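For illustration, the span signature lets an upcoming caller hand over one destination dtype per data handle in a single batched call. A hypothetical usage sketch (GatherLocalShards and the uniform at::kFloat dtype are placeholders, not code from this PR):

// Hypothetical caller: convert a batch of shard handles in one call,
// supplying a destination dtype for each handle.
std::vector<at::Tensor> GatherLocalShards(
    const std::vector<torch::lazy::BackendDataPtr>& shard_handles) {
  std::vector<at::ScalarType> dest_types(shard_handles.size(), at::kFloat);
  // absl::Span<const T> converts implicitly from a const std::vector<T>&.
  return XlaDataToTensors(shard_handles, dest_types);
}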

  std::vector<at::Tensor> tensors(literals.size());
  absl::BlockingCounter counter(literals.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    auto copy_fn = [&, i]() {
Collaborator

Can you capture the variables you need explicitly?

Collaborator

Actually, you need just about every variable in this scope since it's pretty narrow. I take that back.

@@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(

 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
   std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
Collaborator

I wonder if we should just be returning Tensors here

Collaborator Author

I'm interested in making TransferFromServer return at::Tensor and cut out the xla::Literal middleman, but that's in the idea phase. Opted to keep this change smaller and just distribute the copy work over more cores.

jonb377 (Collaborator Author) commented Nov 20, 2023

Thanks for the reviews @will-cromar and @JackCaoG!

yeounoh (Contributor) left a comment

LGTM

jonb377 merged commit ec54fd4 into master on Dec 1, 2023
19 checks passed
jonb377 deleted the jonbolin/copy-pool branch on December 1, 2023 at 01:55
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
* Distribute Literal->Tensor copies across thread pool

* Update for pytorch#5799
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* Distribute Literal->Tensor copies across thread pool

* Update for pytorch#5799
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* Distribute Literal->Tensor copies across thread pool

* Update for #5799
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* Distribute Literal->Tensor copies across thread pool

* Update for #5799