[BUG] [Fix-Suggested] KeyError in stage_1_and_2.py Due to Optimizer-Model Parameter Mismatch #6770
Comments
@traincheck-team, thanks for the detailed report. Are you able to provide a fix?

@tjruwase Personally, I think both should be done. Fix 1 should be implemented to raise an exception when invoking `deepspeed.initialize` on a model/optimizer that has already been initialized, and fix 2 can be implemented to raise a friendly warning if `optimizer.param_groups` is not a subset of `model.parameters()`. On the downside, 1 limits API flexibility while 2 adds to already crowded log messages. Let me know what you think!

@traincheck-team, thanks for agreeing to help and sharing the trade-offs. I prefer option 1, because nested `deepspeed.initialize` calls are not a supported usage pattern.

Sure. Will ship the PR soon.
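For illustration only, option 1 could be enforced with a guard along these lines. This is a hedged sketch: the helper name and the marker attribute are hypothetical and are not part of DeepSpeed's actual implementation.

```python
import torch

_ALREADY_INITIALIZED = "_ds_already_initialized"  # hypothetical marker attribute


def guard_against_reinitialization(optimizer: torch.optim.Optimizer) -> None:
    """Raise a clear error if this optimizer was already passed to deepspeed.initialize."""
    if getattr(optimizer, _ALREADY_INITIALIZED, False):
        raise RuntimeError(
            "This optimizer has already been used in a previous deepspeed.initialize "
            "call; repeated/nested initialization is not supported."
        )
    setattr(optimizer, _ALREADY_INITIALIZED, True)


# Usage sketch:
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)
guard_against_reinitialization(opt)    # first call: ok
# guard_against_reinitialization(opt)  # second call: raises RuntimeError
```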
Describe the bug
Related to #3718.
A `KeyError` is thrown inside `deepspeed.initialize`, at `runtime/zero/stage_1_and_2.py`, line 574, in `_create_param_mapping`, due to inconsistent usage of model parameters and parameters managed by the optimizer.
Full Traceback (collapsed)
Suspected Root Cause
```python
deepspeed.initialize(model=model, optimizer=optimizer, config_params=ds_config_fp16)
```
This issue can be triggered whenever, in the arguments to `deepspeed.initialize`, the parameters in `optimizer.param_groups` are not a subset of `model.parameters()`.
At the fault location, the code tries to look up parameter names stored in `self.param_names` using the tensors in `self.bit16_groups`. `self.bit16_groups` is populated from `optimizer.param_groups`, while `self.param_names` is populated from the model itself. Thus, if the optimizer's parameters are not exactly a subset of the model's, a `KeyError` is thrown.
The case where the optimizer's parameters are not exactly a subset of the model's is quite common, due to optimization techniques like parameter grouping and ZeRO optimization.
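The mismatch can be illustrated without DeepSpeed at all. The snippet below is a minimal sketch, assuming only PyTorch: it builds a name map keyed by the model's parameter objects (the way `self.param_names` is populated) and then looks it up with a merged tensor that only the optimizer knows about, which is essentially what `_create_param_mapping` does with the tensors in `self.bit16_groups`.

```python
import torch

# Toy model with four parameters (two Linear layers: weight + bias each).
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

# Name map keyed by the *model's* parameter objects (how self.param_names is built).
param_names = {param: name for name, param in model.named_parameters()}

# Simulate the state after a previous deepspeed.initialize: the optimizer's
# param_groups now hold a single flat tensor that merges the original parameters
# and is therefore not one of the model's parameter objects.
merged = torch.nn.Parameter(
    torch.cat([p.detach().reshape(-1) for p in model.parameters()])
)
optimizer = torch.optim.SGD([merged], lr=0.1)

# Lookup by the tensors in the optimizer-derived groups: the merged tensor is
# not a key of param_names, so this raises KeyError.
for group in optimizer.param_groups:
    for p in group["params"]:
        print(param_names[p])  # KeyError
```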
To Reproduce
We prepared a rather simple script to reproduce this error. In this script, `deepspeed.initialize` is accidentally called twice. After the first `deepspeed.initialize`, `optimizer.param_groups` is consolidated into a single parameter, which causes the `KeyError` in the second `deepspeed.initialize`. (A minimal sketch of such a script is shown after the steps below.)
1. Install DeepSpeed 0.15.4.
2. Run `bug.py` using `deepspeed --num_gpus=1 bug.py`.
3. Notice that the second `deepspeed.initialize` throws the `KeyError` exception.
Also notice that the first print of `optimizer.param_groups` shows 4 params, while the second print shows only one param (whose content is the merge of the 4 params).
Prior to `deepspeed.initialize`: (output collapsed)
After `deepspeed.initialize`: (output collapsed)
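The original `bug.py` is not reproduced verbatim here. A minimal sketch along the same lines (a small model, a plain optimizer, and two back-to-back `deepspeed.initialize` calls) might look like the following; the model shape and config values are illustrative assumptions, not the actual script.

```python
# bug_sketch.py -- illustrative only; run with: deepspeed --num_gpus=1 bug_sketch.py
import torch
import deepspeed

ds_config_fp16 = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

# Two Linear layers -> 4 parameters, matching the "4 params" printed above.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

print("prior to deepspeed.initialize:", optimizer.param_groups)

# First call: works, but ZeRO consolidates the original optimizer's
# param_groups into a single flattened parameter in place.
model_engine, _, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config_params=ds_config_fp16
)

print("after deepspeed.initialize:", optimizer.param_groups)

# Second (accidental) call: the merged tensor in optimizer.param_groups is not
# in model.parameters(), so _create_param_mapping raises KeyError.
model_engine, _, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config_params=ds_config_fp16
)
```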
Since in the second `deepspeed.initialize` call the merged param does not actually exist in the model, a `KeyError` is thrown.
Expected behavior / Suggested Fix
We expect two behaviors here from DeepSpeed:
1. Raise an exception when `deepspeed.initialize` is called on models / optimizers that have already been used in another `deepspeed.initialize`.
2. Check explicitly that "`optimizer.param_groups` should be a subset of `model.parameters()`" and throw a more user-friendly exception or warning.
ds_report output (collapsed)
I will be more than happy to contribute to the two suggested fixes; let me know what you think!
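For reference, a minimal sketch of what suggested check (2) could look like; the function name is hypothetical and this is not DeepSpeed's actual API.

```python
import torch


def check_optimizer_params_in_model(model: torch.nn.Module,
                                     optimizer: torch.optim.Optimizer) -> None:
    """Hypothetical pre-flight check: every optimizer parameter must be a model parameter."""
    model_params = {id(p) for p in model.parameters()}
    for group in optimizer.param_groups:
        for p in group["params"]:
            if id(p) not in model_params:
                raise ValueError(
                    "optimizer.param_groups contains a tensor that is not in "
                    "model.parameters(); was deepspeed.initialize already called "
                    "on this optimizer?"
                )
```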