Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory by using all_gather_into_tensor #1968

Merged
merged 12 commits into from
Oct 10, 2023
Merged

Reduce memory by using all_gather_into_tensor #1968

merged 12 commits into from
Oct 10, 2023

Conversation

muellerzr
Copy link
Collaborator

@muellerzr muellerzr commented Sep 13, 2023

What does this PR do?

torch 1.13 introduced publicly all_gather_into_tensor (before this was _all_gather_base) which is a much more memory efficient version of gather. One thing to note mentioned in the PR here is they did not have this be the base of gather, since it handles uneven inputs automatically. Since Accelerate does this separately and checks, we can safely use this API. Original DeepSpeed PR I discovered showing this

For a general idea of just how much memory can be stored, I ran a small simple test:

import time
import torch
from accelerate import PartialState
from accelerate.utils import gather

def convert_bytes(size):
    "Converts `size` from bytes to the largest possible unit"
    for x in ["bytes", "KB", "MB", "GB", "TB"]:
        if size < 1024.0:
            return f"{round(size, 2)} {x}"
        size /= 1024.0

    return f"{round(size, 2)} PB"

state = PartialState()
tensor = torch.rand((64, 224, 224, 64), device=state.device)

# Using `PartialState`
start_time = time.time()
tensor = gather(tensor)
end_time = time.time()

with state.main_process_first():
    print(f"Process {state.process_index} memory allocated after: {convert_bytes(torch.cuda.max_memory_allocated(state.device))}")
    print(f"Process {state.process_index} time: {end_time - start_time}")

The results can be summarized as such:

Before:
Total CUDA memory allocated: 3.83gb

After:
Total CUDA memory allocated: 2.3gb

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@BenjaminBossan @LysandreJik

@muellerzr muellerzr added the enhancement New feature or request label Sep 13, 2023
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 13, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvement. Before proceeding, I have a few questions, please take a look. Also some minor suggestions, but those are not blockers.

src/accelerate/utils/operations.py Show resolved Hide resolved
src/accelerate/utils/operations.py Outdated Show resolved Hide resolved
src/accelerate/utils/operations.py Outdated Show resolved Hide resolved
state = PartialState()

if state.backend is not None and state.backend != "gloo":
output_tensors = torch.zeros(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use torch.zeros instead of torch.empty_like as previously?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're doing something slightly different here with the new API, where this gather works using a different tensor dim than before which is more efficient.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but won't this allocate more memory than previously?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually the opposite, hence what we're doing here.

src/accelerate/utils/operations.py Outdated Show resolved Hide resolved
Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @muellerzr for working on this, I have the same comments as Benjamin.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general looks good, thanks for adding support for this feature. I have only one question, but feel free to merge.

state = PartialState()

if state.backend is not None and state.backend != "gloo":
output_tensors = torch.zeros(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but won't this allocate more memory than previously?

@muellerzr muellerzr requested a review from pacman100 October 2, 2023 18:15
@muellerzr muellerzr merged commit 73640d0 into main Oct 10, 2023
26 checks passed
@muellerzr muellerzr deleted the gather-op branch October 10, 2023 14:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants