-
Notifications
You must be signed in to change notification settings - Fork 989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix unwrap_model
for distributed compiled model.
#3282
base: main
Are you sure you want to change the base?
Conversation
fb3809e
to
f601b8c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @ggoggam, this is actually intended behavior as you can see from this PR #1437.
What I can suggest modify either unwrap_model
or add a new argument to extract_model_from_parallel
. cc @muellerzr
extract_model_from_parallel
to fully unwrap compiled model.unwrap_model
for distributed compiled model.
I see. I think it would make more sense to modify |
Thanks for the PR. Regarding the implementation, |
My first commit actually fixes accelerate/tests/test_utils.py Line 252 in cb8b7c6
If I understand correctly, it should be assert compiled_model._orig_mod == compiled_model_unwrapped if |
Hmm, good question, I'll leave that to the others to answer, as I'm not sure. |
That's right. But the current behavior is that we don't unwrap the compiled model with cc @muellerzr what do you prefer, modify how |
src/accelerate/utils/other.py
Outdated
@@ -77,8 +77,7 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive | |||
""" | |||
options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel) | |||
|
|||
is_compiled = is_compiled_module(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To clarify: Is is_compiled
being assigned here and then checked after DDP/DS is unwrapped? If so, this seems potentially unnecessary or wrong as it checks for compilation only before unwrapping DDP/DS
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I agree that we should check for compilation after
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I think I misunderstood the code. is_compiled
is assigned here before unwrapping to keep the torch compilation, though I still think this could be a problem if the model is in the form of Distributed(Compiled(model))
. Refer to my newest commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! Thanks for trying to tackle this. I think what I'd rather see instead is in extract_model_from_parallel
we should add a new param arg called keep_torch_compile
, which similar to keep_fp32_wrapper
should default to False
for a number of versions and then we likely should flip that to True after awhile, & modify the logic in extract_model_from_parallel
to reflect this change.
Can you tweak this PR to so? :)
Sure thing. I added I am also curious about what you think in the case of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this and the tests !
Also, note that we have to also upstream the modification done to extract_model_from_parallel
in transformers if needed.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
src/accelerate/accelerator.py
Outdated
@@ -2601,7 +2601,7 @@ def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False): | |||
""" | |||
return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first) | |||
|
|||
def unwrap_model(self, model, keep_fp32_wrapper: bool = True): | |||
def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you wanted to set keep_torch_compile to True for a couple of version @muellerzr ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct, let's default to True
until 2.5.0, to give users ~3 months.
Essentially we should default to None
, and then if it gets None
warn that the default to this will be changing
Co-authored-by: Marc Sun <[email protected]>
What does this PR do?
This PR fixes the unexpected behavior of
Accelerator.unwrap_model
(Issue #3281). Right now, if the model is wrapped in both distributed wrapper (e.g.DistributedDataParallel
orDeepSpeedEngine
) and compiled module (OptimizedModule
) it only unwraps the distributed module. This behavior arises from the following code in L80 ofutils/others.py
:accelerate/src/accelerate/utils/other.py
Line 80 in cb8b7c6
Instead of checking for compiled model both before and after unwrapping distributed wrapper, the current code only checks for compilation before unwrapping the distributed wrapper. If the model is wrapped in both,
is_compiled
will be set toFalse
and won't unwrap the model fully, resulting in unexpected behavior (users expect fully unwrapped model before saving, but getsOptimizedModule
instead, which may result in an error when loading the state dict due to key mismatch).Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@muellerzr @SunMarc