multi-gpu: test_model_parallel_beam_search tests fail with "RuntimeError: Expected all tensors to be on the same device" #35762
Comments
Fixing the following errors in a few models:

```
>       hidden_states = inputs_embeds + pos_embeds
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:3!
```

Fixes: huggingface#35762
Signed-off-by: Dmitry Rogozhkin <[email protected]>

See proposed fix in: huggingface#35763
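For readers unfamiliar with the error, here is a minimal sketch of the failing pattern and the usual remedy, an explicit device move before the sum (the helper name is illustrative; the actual change in huggingface#35763 may differ):

```python
import torch

def add_position_embeddings(inputs_embeds: torch.Tensor, pos_embeds: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. When a model is sharded across devices with a
    # device_map, the two embedding tensors can land on different devices;
    # moving one operand to the other's device avoids the RuntimeError.
    pos_embeds = pos_embeds.to(inputs_embeds.device)
    return inputs_embeds + pos_embeds
```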
I ran the tests with 2 GPUs (CUDA) and only one test fails. So I don't think this is an issue with the modeling code. Could you explore a bit more why this happens only on XPU?
@SunMarc: unfortunately I don't have a system with 2+ CUDA GPUs handy, so I can't compare execution side by side. As for the "how", what happens on XPU is clear: tensors involved in the same operation (a sum, for example) are placed on different devices. The next question is "why" we see this on XPU, and you suggest it is not seen on CUDA. With the latter I disagree: I found at least 2 changes in the existing Transformers sources that fixed similar issues for CUDA. I left respective comments in PR huggingface#35763 and am reposting them here:
I also see the discussion #30836 (comment) between you and @ArthurZucker on why this change was made. It seems that issues similar to the one I filed here were seen for CUDA as well. With the "why" answered, could you please give me some guidance:
@SunMarc, as I noted in the previous comment, I think the issue is related to how the model is distributed among multiple GPU devices. I had 4 XPU devices when I initially filed the issue, with multiple tests failing. But if I force only 2 GPU devices, then I see the same result as you: only one case fails, for qwen. Same as for CUDA:
Can someone try to run this on a system with 4 CUDA devices? But overall, if I understand correctly,
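A rough sketch of how one might compare 2- versus 4-device sharding on CUDA (the model choice and device count here are assumptions, and XPU uses a different device-masking environment variable):

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # set to "0,1" to force 2 devices

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # one of the models listed in the issue body below
tok = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" shards the model across all visible devices (needs accelerate).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tok("Hello", return_tensors="pt").to("cuda:0")
# Beam search is what test_model_parallel_beam_search exercises.
out = model.generate(**inputs, num_beams=4, max_new_tokens=8)
print(tok.decode(out[0]))
```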
Thanks for giving more context and reproducing the results for 2 devices. I'll have a look at the PR then.
That's pretty much it. We put modules that contain a residual connection inside `_no_split_modules`.
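For context, a minimal sketch of that mechanism (the class and layer names are illustrative, not taken from the PR):

```python
from transformers import PreTrainedModel

class MyModelPreTrainedModel(PreTrainedModel):
    # Modules listed here are never split across devices by the device_map
    # planner, so residual additions inside them stay on a single device.
    _no_split_modules = ["MyModelDecoderLayer"]
```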
@SunMarc: most of the changes in the PR are around `_no_split_modules`.
With:
On:
`test_model_parallel_beam_search` tests for a number of models fail with "RuntimeError: Expected all tensors to be on the same device". From the log, errors occur on the following lines (all at commit 7d4b3dd; see the sketch after this list):
- src/transformers/models/aria/modeling_aria.py, line 1116
- src/transformers/models/falcon_mamba/modeling_falcon_mamba.py, line 312
- src/transformers/models/gpt2/modeling_gpt2.py, line 821
- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py, line 962
- src/transformers/models/idefics2/modeling_idefics2.py, line 1307
- src/transformers/models/imagegpt/modeling_imagegpt.py, line 778
- src/transformers/models/instructblip/modeling_instructblip.py, line 1609
- src/transformers/models/instructblipvideo/modeling_instructblipvideo.py, line 1644
- src/transformers/models/mamba/modeling_mamba.py, line 264
- src/transformers/models/mbart/modeling_mbart.py, line 494
- src/transformers/models/mllama/modeling_mllama.py, line 1489
- src/transformers/models/opt/modeling_opt.py, line 885
- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py, line 1489
- src/transformers/models/xglm/modeling_xglm.py, line 597
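Every listed line follows the same shape: an elementwise sum whose operands were produced on different shards. A self-contained demonstration of the failure mode (two CUDA devices stand in for the xpu:2/xpu:3 from the logs):

```python
import torch

if torch.cuda.device_count() >= 2:
    inputs_embeds = torch.randn(1, 4, 8, device="cuda:0")
    pos_embeds = torch.randn(1, 4, 8, device="cuda:1")
    try:
        hidden_states = inputs_embeds + pos_embeds  # operands on different devices
    except RuntimeError as e:
        print(e)  # "Expected all tensors to be on the same device, ..."
```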
CC: @SunMarc @ydshieh @faaany