Description
We experienced wrong GPU placement when doing MoE with ZeRO Stage 3. We use module.id to control which expert gets loaded onto which GPU for fine-grained control, and we found that module.id gets corrupted after deepspeed.initialize.
Suspected Root Cause
DeepSpeed uses .id in ZeRO Stage 3 optimization to manage state, as seen in runtime/zero/parameter_offload.py:L271. This practice is brittle in two ways:
- id is an overly generic attribute name that can easily collide with user-defined attributes.
- There is no check on the .id attribute before setting it, which allows accidental overwrites of the attribute and causes hard-to-diagnose problems.
In the specific bug we encountered (bug.py provided below), each expert module is identified by its .id attribute, but during initialization .id is overwritten by the _register_hooks_recursively function in deepspeed/runtime/zero/stage3.py, scrambling the expert-GPU placement.
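For context, the fine-grained placement mentioned above relies on module.id roughly as in the sketch below. This is a hypothetical illustration of the pattern (place_experts is not our actual code):

```python
# Hypothetical sketch of .id-based expert placement (illustrative only).
# Each rank keeps the experts whose user-assigned .id maps to it; if DeepSpeed
# silently renumbers .id, this mapping breaks.
import torch

def place_experts(experts, rank, world_size):
    for expert in experts:
        if expert.id % world_size == rank:
            expert.to(torch.device("cuda", rank))
```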
To reproduce
1. Install deepspeed 0.15.4.
2. Run bug.py using deepspeed --num_gpus=2 bug.py (the num_gpus argument does not matter here; use 1 if you do not have a multi-GPU node).
bug.py:
```python
import torch
import deepspeed
from torch.nn import Module, Linear


# Define a simple expert module
class Expert(Module):
    def __init__(self, id):
        super().__init__()
        self.id = id  # ID for custom GPU placement
        self.fc = Linear(128, 128)

    def forward(self, x):
        return self.fc(x)


# Create a model with 60 experts
class MoEModel(Module):
    def __init__(self):
        super().__init__()
        self.experts = torch.nn.ModuleList([Expert(i) for i in range(60)])

    def forward(self, x, expert_id):
        return self.experts[expert_id](x)


# Helper function to log expert ids
def log_expert_ids(model, rank):
    loaded_experts = [e.id for e in model.experts]
    print(f"rank {rank}: expert ids = {loaded_experts}")


def main():
    deepspeed.init_distributed()
    rank = torch.distributed.get_rank()

    # Create model
    model = MoEModel()
    log_expert_ids(model, rank)  # prints 0, 1, 2, ..., 59

    # Configure DeepSpeed
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=3e-5),
        config={
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "zero_optimization": {"stage": 3},
        },
    )

    # Print expert ids again after deepspeed.initialize
    log_expert_ids(model, rank)  # prints 0, 2, 4, 6, ...
    # If you call deepspeed.initialize here again, the ids get completely messed up.

    dummy_input = torch.randn(1, 128).cuda(rank)
    for expert_id in range(60):
        model_engine(dummy_input, expert_id=expert_id)


if __name__ == "__main__":
    main()
```
We print the ids of all experts twice, once before deepspeed.initialize and once after. Observe that the first print gives 0, 1, 2, ..., 59, while the second gives 2, 4, 6, 8, ..., 120 (presumably because the recursion counter also numbers the wrapping modules and each expert's inner Linear layer, so every expert's id advances in steps of 2).
The code in ZeRO Stage 3 responsible for this overwrites the .id attribute: module.id is set to a value based on a counter (my_count), which conflicts with the user-defined .id attributes used for expert placement.
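Paraphrased from memory of deepspeed 0.15.x, the hook-registration logic looks roughly like the sketch below (not verbatim; the surrounding hook code is omitted):

```python
# Rough paraphrase of ZeRO-3 hook registration (not the exact DeepSpeed source).
# Every submodule visited by the recursion gets an integer id from a shared
# counter, silently clobbering any pre-existing user-defined .id attribute.
def _register_hooks_recursively(self, module, count=[0]):
    my_count = count[0]
    module.id = my_count  # <-- overwrites the user-defined module.id

    for child in module.children():
        count[0] = count[0] + 1
        self._register_hooks_recursively(child, count=count)

    # ... forward/backward partitioning hooks for `module` are registered here ...
```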
Bug Significance
This bug can significantly affect model behavior when expert modules are incorrectly placed across GPUs, leading to incorrect training outcomes or potential crashes. Ensuring that internal DeepSpeed modifications do not overwrite user-defined attributes is crucial for stability and expected functionality.
Even if user-side conflicts are not in your scope, DeepSpeed itself can accidentally modify these attributes as well. For example, you can reproduce the same problem by calling deepspeed.initialize multiple times.
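A minimal sketch of that repeated-initialization scenario, reusing the names from bug.py above (ds_config stands for the same inline config dict; the second call may also run into the optimizer mismatch described in #6770):

```python
# Hypothetical sketch, not part of the original report: a second
# deepspeed.initialize re-runs ZeRO-3 hook registration over the already
# wrapped module tree, renumbering module.id yet again.
engine, _, _, _ = deepspeed.initialize(
    model=model, optimizer=torch.optim.Adam(model.parameters(), lr=3e-5), config=ds_config)
log_expert_ids(model, rank)  # ids already renumbered once

engine, _, _, _ = deepspeed.initialize(
    model=model, optimizer=torch.optim.Adam(model.parameters(), lr=3e-5), config=ds_config)
log_expert_ids(model, rank)  # ids renumbered / corrupted further
```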
Thus, we argue for the following fixes / engineering practices for this issue.
Expected Behavior / Suggested Fix
1. Use a Specific Attribute for Internal IDs: instead of overwriting .id, use a more specific attribute name such as _deepspeed_id to avoid conflicts with user-defined attributes.
2. Restrict Attribute Modification: modify the __setattr__ method to only allow setting fields that have not been previously set, preventing unintentional overwrites of user-defined attributes (see the sketch after this list).
3. Forbid Duplicate Calls to deepspeed.initialize: we observe a lot of issues with accidental duplicate calls to deepspeed.initialize, so we suggest forbidding duplicate calls by recording the models / optimizers that have already been initialized, as mentioned in [BUG] [Fix-Suggested] KeyError in stage_1_and_2.py Due to Optimizer-Model Parameter Mismatch #6770.
I will be more than happy to contribute the suggested fixes, let me know what you think!
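For illustration only, here is a minimal sketch of the three suggestions; the names _deepspeed_id, _safe_set_module_attr, and _ALREADY_INITIALIZED are hypothetical and not existing DeepSpeed APIs:

```python
# Hypothetical sketch of the three suggestions (illustrative, not DeepSpeed code).

def _safe_set_module_attr(module, name, value):
    """Refuse to silently overwrite an attribute that is already set."""
    if hasattr(module, name):
        raise RuntimeError(
            f"Attribute '{name}' already exists on {module.__class__.__name__}; "
            "refusing to overwrite it during DeepSpeed initialization.")
    setattr(module, name, value)

def register_internal_module_id(module, my_count):
    # Suggestion 1: use a namespaced attribute instead of the generic `.id`.
    # Suggestion 2: guard the assignment so collisions (user attributes or a
    # repeated initialization) fail loudly instead of corrupting state.
    _safe_set_module_attr(module, "_deepspeed_id", my_count)

# Suggestion 3: remember which models have already been initialized so a
# duplicate deepspeed.initialize call can be rejected with a clear error.
_ALREADY_INITIALIZED = set()

def check_not_already_initialized(model):
    if id(model) in _ALREADY_INITIALIZED:
        raise RuntimeError("deepspeed.initialize was already called on this model")
    _ALREADY_INITIALIZED.add(id(model))
```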
@traincheck-team, please see #6847. I simplified your repro into a simple unit test. Please advise if this simplification is missing crucial aspects of this issue. Thanks!
Thanks for the fast response! @tjruwase
The test and implementation look good, and they guard against DeepSpeed modifying user-defined attributes.
I think the missing piece is to guard against DeepSpeed itself corrupting that ds_id attribute due to repeated initialization, but that should be handled in a separate PR that we've agreed to implement in #6770. We will try to ship that PR soon.
> I think the missing piece is to guard against DeepSpeed itself corrupting that ds_id attribute due to repeated initialization, but that should be handled in a separate PR that we've agreed to implement in #6770. We will try to ship that PR soon.
Yes, for this problem, I think your PR is the best solution, as discussed. Thanks!