-
Notifications
You must be signed in to change notification settings - Fork 989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
create _preprare_fsdp to pre- prepare fsdp model training #3213
base: main
Are you sure you want to change the base?
create _preprare_fsdp to pre- prepare fsdp model training #3213
Conversation
@muellerzr have you any feedback? |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this important issue.
I wonder on the overall approach to solving this issue. Right now, IIUC, the idea is to re-initialize the optimizer(s) using the FSDP-wrapped models. This could potentially be error prone and ideally should be well tested before we merge.
Would it be possible instead to check if this has happened and raise an error, then direct users towards not initializing the optimizer prematurely, and handle the initialization with accelerate so that it's in the right order? Maybe that can't work, but I think it's worth considering different solutions.
Also, in case this is really the same issue as this one reported in PEFT, it means that it used to work correctly in previous versions of accelerate/transformers. In that case, I wonder what changed that resulted in the issue.
# Validate the presence of models and optimizers | ||
if not models and not optimizers: | ||
return args | ||
|
||
# Flattening weights implies that the optimizers have already been processed. | ||
if next(next(iter(models.values())).named_parameters())[0].endswith("_flat_param"): | ||
return args | ||
|
||
if len(models) != len(optimizers): | ||
raise ValueError( | ||
f"The number of models ({len(models)}) must be equal to the number of optimizers ({len(optimizers)})." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to move these checks to the very start of the method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@BenjaminBossan The method! Do you mean .prepare
? What are the benefits of doing so?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant _prepare_fsdp
. It is a common pattern to perform all checks as early as possible, so following this makes the code easier to understand for readers. This is especially so if there are early returns.
The reason why it's common is so that we don't do any unnecessary work if the checks fail anyway. In this case, there is no need to determine models
or optimizers
if we're going to raise an error later. By skipping the unnecessary work, we ensure faster execution and prevent possibly unwanted side-effects (this might not be relevant here right now but code will change in the future and then it could be true).
src/accelerate/accelerator.py
Outdated
# Clear parameter lists. | ||
for opt in optimizers.values(): | ||
for group in opt.param_groups: | ||
group["params"] = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it better to call group["params"].clear()
? That would affect references.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates and explaining your testing and reasoning further.
Raising an error when the optimizer is initialized prematurely could be a solution, but two cases emerge
These are valid concerns. At the end of the day, it's a trade off and we need to decide which cost we'd rather pay.
The Transformers Trainer internally uses delay_optimizer_creation and creates the optimizer after FSDP wrapping.
I wonder if this logic could be moved/copied to accelerate.
For PEFT training, which involves mixing frozen and non-frozen parameters, use_orig_params=True must be used
I don't think this is strictly necessary, we have examples in PEFT with use_orig_params=False
. But I have to admit I don't know what exactly changes under the hood in FSDP when this parameter is set. Note also that in the linked PEFT issue, I tried setting use_orig_params=True
and it didn't help.
# Validate the presence of models and optimizers | ||
if not models and not optimizers: | ||
return args | ||
|
||
# Flattening weights implies that the optimizers have already been processed. | ||
if next(next(iter(models.values())).named_parameters())[0].endswith("_flat_param"): | ||
return args | ||
|
||
if len(models) != len(optimizers): | ||
raise ValueError( | ||
f"The number of models ({len(models)}) must be equal to the number of optimizers ({len(optimizers)})." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant _prepare_fsdp
. It is a common pattern to perform all checks as early as possible, so following this makes the code easier to understand for readers. This is especially so if there are early returns.
The reason why it's common is so that we don't do any unnecessary work if the checks fail anyway. In this case, there is no need to determine models
or optimizers
if we're going to raise an error later. By skipping the unnecessary work, we ensure faster execution and prevent possibly unwanted side-effects (this might not be relevant here right now but code will change in the future and then it could be true).
|
This is true but using the right FSDP auto wrap policy, trainable and frozen parameters should be prevented from being mixed.
If this is an option here, it sounds like the more robust solution to me. Let's wait for @muellerzr return to office and get his opinion on this. |
This feels a bit overengineered and the commit to transformers broke a lot of users code with saving checkpoints. For now let's not move forward with this IMO, since otherwise this would imply that the prior change broke all FSDP code, which isn't the case. |
Fixes #3209
Who can review?