
Fix: get_balanced_memory when using multiple GPUs with small models or quantized models with a large vocabulary #3244

Open
wants to merge 3 commits into main

Conversation

MekkCyber

What does this PR do?

Fixes huggingface/transformers#34751, and part of huggingface/transformers#34706.

Explanation of get_balanced_memory Behavior and Handling of Large embed_tokens Layers

When a small model or a relatively large quantized model is loaded with device_map="auto", the get_balanced_memory function computes a maximum memory budget for each visible device. Its goal is to distribute the model evenly across all devices while leaving all remaining memory available on the last device.
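For reference, here is a minimal sketch of that code path, assuming a 2-GPU machine and the (gated) Gemma2 2B checkpoint from the linked issue; get_balanced_memory produces the per-device budget that infer_auto_device_map then fills module by module:

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_balanced_memory
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton on the meta device so no weights are downloaded or allocated.
config = AutoConfig.from_pretrained("google/gemma-2-2b")  # example model from the issue; substitute any causal LM
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Per-device budget that spreads the model evenly, leaving the leftover memory
# to the last device (or keeping device 0 mostly free when low_zero=True).
max_memory = get_balanced_memory(model, dtype=torch.float16, low_zero=False)

# Walks the modules in order and assigns each one to the first device whose
# remaining budget can still hold it.
device_map = infer_auto_device_map(model, max_memory=max_memory, dtype=torch.float16)
print(device_map)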

Issue with Large embed_tokens Layers

For models with a small number of parameters but a large vocabulary (e.g., Gemma2 2B), the embed_tokens layer can consume a significant amount of memory and may exceed the max_memory limit of the first devices. When that happens, infer_auto_device_map skips every earlier device and places the embed_tokens layer, along with all subsequent layers, on the last device.

A similar issue arises with quantized models. Since embedding layers are often not quantized, the max_memory per device might be insufficient to accommodate the embed_tokens layer, leading to the same behavior.
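A rough illustration of this failure mode, reusing the meta-initialized model from the sketch above; the per_gpu value here is a made-up balanced share, not one computed by accelerate:

import torch
from accelerate.utils import compute_module_sizes

# Sizes (in bytes) of every submodule, keyed by dotted name. The embedding is
# "model.embed_tokens" for this architecture, but the name is model-specific.
module_sizes = compute_module_sizes(model, dtype=torch.float16)
embed_size = module_sizes["model.embed_tokens"]

per_gpu = 1 * 1024**3  # hypothetical per-device share for a small or quantized model

if embed_size > per_gpu:
    # infer_auto_device_map cannot fit embed_tokens within this budget, so it skips the
    # device; since placement is sequential, every later layer also lands on the last device.
    print(f"embed_tokens needs {embed_size / 1024**3:.2f} GiB, above the {per_gpu / 1024**3:.2f} GiB budget")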

Improvement in This PR

This pull request addresses the issue by comparing the size of the embed_tokens layer to the memory available per GPU (per_gpu). It adjusts the memory allocation strategy to ensure that embed_tokens can be distributed appropriately across devices, improving the handling of models with large embeddings.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 1069 to 1074
if idx == 0 and not low_zero and module_sizes["model.embed_tokens"] > per_gpu * 0.9:
    max_memory[idx] = min(module_sizes["model.embed_tokens"] * 1.3, max_memory[idx])
elif idx == 1 and low_zero and module_sizes["model.embed_tokens"] > per_gpu * 0.9:
    max_memory[idx] = min(module_sizes["model.embed_tokens"] * 1.3, max_memory[idx])
else:
    max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
Member


Not every model has its embedding layer named model.embed_tokens. Since this is specific to transformers, I don't think this should live in accelerate; we can modify max_memory directly there as it is computed. We are also trying to tackle similar issues with this PR #3066 (comment).

Maybe a good solution would be to check whether there is a module larger than per_gpu and, if so, return a message saying that the model is unbalanced, which will lead to the whole model being put on only one device, and propose that the user switch to "sequential" mode instead? Or we could do as you suggested and modify max_memory using the largest module size.
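A possible sketch of that warning option (not code from this PR); max_leave_size and per_gpu mirror the names used in the diff and stand for the largest leaf-module size and the balanced per-device share, in bytes:

import warnings

# Toy values standing in for what get_balanced_memory computes internally.
max_leave_size = int(1.2 * 1024**3)  # largest leaf module, e.g. a large embedding
per_gpu = 1 * 1024**3                # balanced per-device share

if max_leave_size > per_gpu:
    warnings.warn(
        "One module is larger than the balanced per-GPU budget, so the whole model may be "
        "placed on a single device. Consider device_map='sequential' or a custom max_memory."
    )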

Author


Thanks for the suggestion @SunMarc, I updated the code to use the largest leaf module size instead of hardcoding the embed_tokens layer.

Comment on lines +1070 to +1073
if idx == 0 and not low_zero and max_leave_size > per_gpu * 0.9:
    max_memory[idx] = min(max_leave_size * 1.3, max_memory[idx])
elif idx == 1 and low_zero and max_leave_size > per_gpu * 0.9:
    max_memory[idx] = min(max_leave_size * 1.3, max_memory[idx])
Member


You are taking the minimum; is this expected?

Author


Yes, we take the minimum of the GPU's max_memory and the space needed on that device, so if the space needed exceeds what is available on the GPU, we only allocate what is available.
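A toy example of that capping behaviour, with made-up numbers:

# The largest leaf module asks for 1.3x its size as headroom, but the device only
# reports 6.5 GiB, so the allocation is capped at what is actually available.
max_leave_size = 6 * 1024**3             # ~6 GiB largest leaf module
device_max_memory = int(6.5 * 1024**3)   # memory available on this GPU
allocated = min(int(max_leave_size * 1.3), device_max_memory)
assert allocated == device_max_memory    # 7.8 GiB requested, capped at 6.5 GiB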

@MekkCyber MekkCyber requested a review from SunMarc November 22, 2024 12:06
Member

@SunMarc SunMarc left a comment


Could you add a test that fails before this PR but passes now? Also, it would be nice to try it with the models that have this issue.

Development

Successfully merging this pull request may close these issues.

Device_map='auto' not working along with bitsandbytes (transformers)
3 participants