fsdp refactoring #2177

Merged (8 commits, Nov 24, 2023)

Conversation

@pacman100 (Contributor) commented Nov 21, 2023

What does this PR do?

FSDP refactoring based on:

  1. Torch 2.1 official docs: https://pytorch.org/docs/stable/fsdp.html
  2. With use_orig_params=True, we no longer require preparing the model before creating the optimizer object. Earlier, we needed to prepare the model, i.e., wrap it with FSDP, before creating the optimizer because of the following warning from the PyTorch official docs:
The optimizer must be initialized after the module has been wrapped with FSDP since FSDP will shard and transform the module’s parameters in a way that may not preserve the original parameter variables. Thus, the previously initialized optimizer may have stale references to the parameters.

Now, with use_orig_params=True, that is no longer the case. This makes the Accelerate training API consistent: users on a single GPU, DDP, FSDP, or DeepSpeed all follow the same logic shown below:

model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, lr_scheduler, train_dataloader, eval_dataloader)
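
A minimal end-to-end sketch of this unified flow (the toy model, data, and hyperparameters here are illustrative, not part of the PR; FSDP itself is still configured via accelerate config):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # FSDP/DDP/DeepSpeed settings come from `accelerate config`

model = torch.nn.Linear(128, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # created before prepare()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
train_dataloader = DataLoader(
    TensorDataset(torch.randn(64, 128), torch.randint(0, 2, (64,))), batch_size=8
)

# one prepare() call, identical for single GPU, DDP, FSDP and DeepSpeed
model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
    model, optimizer, lr_scheduler, train_dataloader
)

for inputs, labels in train_dataloader:
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    accelerator.backward(loss)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()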

Earlier, for FSDP, the recommended practice was the one shown below; otherwise we used to recreate the optimizer after preparing the model, which did not preserve optimizer parameter groups. All of that is now resolved, and optimizer parameter groups are also supported.

model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = ...
optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(optimizer, lr_scheduler, train_dataloader, eval_dataloader)

As such, use_orig_params=True is now the default.
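
If needed, the new default can still be overridden programmatically through the FSDP plugin. A hedged sketch, assuming the FullyShardedDataParallelPlugin dataclass exposes a use_orig_params field (note that with use_orig_params=False the old wrap-the-model-first ordering applies again):

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# only use_orig_params is set here; the remaining FSDP options (sharding strategy,
# auto wrap policy, ...) are still picked up from `accelerate config` / env vars
fsdp_plugin = FullyShardedDataParallelPlugin(use_orig_params=False)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)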

  1. https://github.com/facebookresearch/llama-recipes: Using this as a best-practices guide for FSDP, we are in line with it for all the features and usage of the FSDP APIs. For checkpointing, they support FULL_STATE_DICT and SHARDED_STATE_DICT; we support both of these as well and already have tests for them (see the sketch after this list). They don't show how to save and load for the LOCAL_STATE_DICT state dict type.
  2. Regression: a test for the LOCAL_STATE_DICT checkpointing feature of FSDP is now failing. Couldn't find anything about it in llama-recipes, the FSDP documentation, the torch FSDP codebase https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp, or elsewhere on the internet. Will raise an issue with the PyTorch team regarding it.
  3. Ran all the slow tests for FSDP and all green! Updated documentation and examples.
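
For reference, a hedged sketch of how checkpointing is typically driven from the Accelerate side once everything has been prepared; the directory name is illustrative, and the state dict type actually used (FULL_STATE_DICT or SHARDED_STATE_DICT) comes from the FSDP configuration, e.g. the fsdp_state_dict_type entry written by accelerate config:

# after accelerator.prepare(...)
accelerator.save_state("checkpoints/step_100")  # saves model/optimizer state per the configured FSDP state dict type
# ... later, to resume:
accelerator.load_state("checkpoints/step_100")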

@HuggingFaceDocBuilderDev commented Nov 21, 2023

The documentation is not available anymore as the PR was closed or merged.

@pacman100 pacman100 requested review from muellerzr and BenjaminBossan and removed request for muellerzr November 21, 2023 12:20
@pacman100 pacman100 marked this pull request as ready for review November 21, 2023 12:21
@BenjaminBossan (Member) left a comment

This looks like a great change, love to see so many lines deleted.

I don't have experience with FSDP, so a few questions:

  1. Does this still work as expected when using PyTorch < 2.1?
  2. use_orig_params default was changed to True. Is there any disadvantage to that, e.g. more memory usage?

src/accelerate/accelerator.py: review comment (resolved)
src/accelerate/accelerator.py: review comment (outdated, resolved)
@pacman100 (Contributor, Author) commented Nov 21, 2023

Hello Benjamin,

Does this still work as expected when using PyTorch < 2.1?
Accelerate now throws an error when the FSDP integration is used with PyTorch < 2.1, because of the lines below.

if is_torch_version("<", FSDP_PYTORCH_VERSION):
    raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

use_orig_params default was changed to True. Is there any disadvantage to that, e.g. more memory usage?

More memory usage -> No.
It is meant to enable the following (see the sketch after this list):

  1. Multiple parameter groups: https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019#tldr-1
  2. It allows non-uniform requires_grad during init, which means support for interspersed frozen and trainable parameters. Think PEFT, wherein the majority of params are frozen and only the adapters are trainable.
  3. It allows creating the optimizer object before wrapping the model in the FSDP module (this is what simplified a lot of code in this PR).
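
A minimal sketch of points 1 and 2 combined: a PEFT-style toy model with a frozen backbone, multiple parameter groups built from the original parameters, and the optimizer created before prepare(); all names here are illustrative:

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# PEFT-style toy model: frozen "backbone" plus a small trainable "head"
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.Linear(128, 2))
for p in model[0].parameters():  # non-uniform requires_grad: backbone frozen
    p.requires_grad = False

# multiple parameter groups over the original parameters, before any FSDP wrapping
optimizer = torch.optim.AdamW(
    [
        {"params": [model[1].weight], "lr": 1e-3},
        {"params": [model[1].bias], "lr": 1e-4, "weight_decay": 0.0},
    ]
)

# per this PR, with use_orig_params=True these parameter references stay valid after wrapping
model, optimizer = accelerator.prepare(model, optimizer)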

It is expected to become the default as per the above dev blogpost:

These semantics to use the original parameters are available today by passing use_orig_params=True to the FSDP constructor, and they were added exactly by augmenting the existing unshard/reshard logic. In that case, named_parameters() returns the original fully-qualified names (FQNs), not ones like .flat_param. This enables using multiple optimizer parameter groups and/or different requires_grad within one FlatParameter's original parameters, and this helps hide the FlatParameter abstraction from users. We hope to converge to setting use_orig_params=True by default in the future.
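
To make the quoted behaviour concrete, a hedged sketch comparing the two modes; it assumes a distributed launch (e.g. torchrun --nproc_per_node=1) on a CUDA device, and the exact flat-parameter name varies across PyTorch versions:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

flat = FSDP(torch.nn.Linear(8, 8).cuda(), use_orig_params=False)
orig = FSDP(torch.nn.Linear(8, 8).cuda(), use_orig_params=True)

# one flattened parameter entry vs. one entry per original parameter (weight, bias)
print([name for name, _ in flat.named_parameters()])
print([name for name, _ in orig.named_parameters()])

dist.destroy_process_group()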

@BenjaminBossan (Member)

Accelerate will throw an error when FSDP integration is used with PyTorch < 2.1

Ah okay, I missed the version bump, thanks for pointing me to it.

It is expected to become default as per the above dev blogpost

Thanks for providing more context. My question arose because use_orig_params=True seems to be strictly better, so I wondered why it wasn't the default and whether there is any disadvantage; it seems to be kept off mainly for backwards compatibility in PyTorch.

@muellerzr (Collaborator) left a comment

Nicely done @pacman100! Excellent refactor and loving that diff. Keeping the simple API all around is a phenomenal win!

@BenjaminBossan (Member) left a comment

Great work, thanks Sourab.
