
multi-gpu: test_model_parallel_beam_search tests fail with "RuntimeError: Expected all tensors to be on the same device" #35762

Open
dvrogozh opened this issue Jan 18, 2025 · 6 comments · May be fixed by #35763


dvrogozh commented Jan 18, 2025

With:

On:

  • 2-card Intel(R) Data Center GPU Max 1550 (aka PVC); note: each card has 2 tiles, so in total there are 4 torch devices available

test_model_parallel_beam_search tests for a number of models fail with "RuntimeError: Expected all tensors to be on the same device":

# TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python3 -m pytest -k test_model_parallel_beam_search \
  tests/models/aria \
  tests/models/falcon_mamba \
  tests/models/gpt2 \
  tests/models/gpt_bigcode \
  tests/models/idefics2 \
  tests/models/imagegpt \
  tests/models/instructblip \
  tests/models/instructblipvideo \
  tests/models/mamba \
  tests/models/mbart \
  tests/models/opt \
  tests/models/qwen2_vl \
  tests/models/xglm
...
FAILED tests/models/aria/test_modeling_aria.py::AriaForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:0!
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:2!
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMHAModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:0!
FAILED tests/models/imagegpt/test_modeling_imagegpt.py::ImageGPTModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/instructblip/test_modeling_instructblip.py::InstructBlipForConditionalGenerationDecoderOnlyTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:1!
FAILED tests/models/instructblipvideo/test_modeling_instructblipvideo.py::InstructBlipVideoForConditionalGenerationDecoderOnlyTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:1!
FAILED tests/models/mamba/test_modeling_mamba.py::MambaModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/mbart/test_modeling_mbart.py::MBartModelTest::test_model_parallel_beam_search - RuntimeError: Expected query, key, and value to have the same device type, but got query.device: xpu:1 key.device: xp...
FAILED tests/models/mllama/test_modeling_mllama.py::MllamaForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:3! (when c...
FAILED tests/models/opt/test_modeling_opt.py::OPTModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/qwen2_vl/test_modeling_qwen2_vl.py::Qwen2VLModelTest::test_model_parallel_beam_search - RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (xpu:1)
FAILED tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:2!
=============================== 15 failed, 2 passed, 4399 deselected, 2 warnings in 13.95s ===============================

From the log, errors occur on the following lines (a sketch of the common fix pattern follows the list):

padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]

hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)

hidden_states = inputs_embeds + position_embeds

hidden_states = inputs_embeds + position_embeds

hidden_states = inputs_embeds + position_embeds

new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states

hidden_states = inputs_embeds + position_embeds

inputs_embeds[special_image_mask] = language_model_inputs.flatten()

inputs_embeds[special_image_mask] = language_model_inputs.flatten()

hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)

attn_output = torch.nn.functional.scaled_dot_product_attention(

hidden_state = torch.cat([class_embedding, hidden_state], dim=1)

hidden_states = inputs_embeds + pos_embeds

input_ids = input_ids[attention_mask[i] == 1]

hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
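
These all boil down to the same kind of mismatch: two operands of a single operation end up on different devices after the model is sharded. A minimal sketch of the pattern and of the kind of fix applied in the linked PR (tensor names here are illustrative, not the exact modeling code):

```python
import torch

device = "xpu:0" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

# Two operands that device_map="auto" may have left on different devices.
inputs_embeds = torch.randn(1, 8, 32, device=device)
position_embeds = torch.randn(1, 8, 32)  # e.g. left on another device

# Typical fix: align the second operand with the first one before the op.
hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

# Same idea for mask-based assignment (the qwen2_vl-style "indices should be
# either on cpu or on the same device" failure): move the mask first.
special_image_mask = torch.zeros(1, 8, dtype=torch.bool)
hidden_states[special_image_mask.to(hidden_states.device)] = 0.0
```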

CC: @SunMarc @ydshieh @faaany

dvrogozh added a commit to dvrogozh/transformers that referenced this issue Jan 18, 2025
Fixing the following errors in a few models:
```
>       hidden_states = inputs_embeds + pos_embeds
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:3!
```

Fixes: huggingface#35762
Signed-off-by: Dmitry Rogozhkin <[email protected]>
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Jan 18, 2025

SunMarc commented Jan 20, 2025

I ran the tests with 2 GPUs (CUDA) and only one is failing. So I don't think this is an issue with the modeling code. Could you explore a bit more why this happens only on xpu?

FAILED tests/models/qwen2_vl/test_modeling_qwen2_vl.py::Qwen2VLModelTest::test_model_parallel_beam_search - RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)
=========================== 1 failed, 14 passed, 4135 deselected, 1 warning in 64.23s (0:01:04) ============================

dvrogozh commented:

@SunMarc : unfortunately I don't have a system with 2+ CUDA devices handy, so I can't compare executions side by side. What happens on XPU is clear as far as the "how" goes: tensors involved in the same operation (a sum or similar) are placed on different devices. The next question is "why" we see this for XPU while you suggest it is not seen on CUDA. I disagree with the latter: I found at least 2 changes in the existing Transformers sources that fix similar issues for CUDA. I left respective comments in PR #35763 and am reposting them here:

conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)

hidden_states = inputs_embeds + positions.to(inputs_embeds.device)

I also see the discussion #30836 (comment) between you and @ArthurZucker on why that change was made. It seems that issues similar to the one I filed here were seen for CUDA as well.

With that said about the "why", could you please give me some guidance on the following:

  1. I spotted the issue for XPU in many models. Do you have concerns with all of these changes or only with a subset of them? For example, some fixes in my PR are similar to the earlier CUDA changes I noted above. Should I split such fixes into a stand-alone PR while we continue discussing the remainder?
  2. I think the answer to the "why" lies in how tensors were placed across multiple devices, and here I would appreciate help and guidance, since I am not very familiar with this logic in Transformers. How does Transformers manage tensor placement for device_map="auto"? Is there any guarantee that the placement will be the same for different kinds of devices sharing some common characteristic, e.g. the same amount of memory? If not, then it is understandable why this is seen for XPU and not for CUDA: the devices are different, so the tensor placement is different, and hence the issues were not seen before. (A small sketch of how I understand the placement to be computed follows this list.)
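
To illustrate question 2, here is roughly how I understand the placement to be computed via Accelerate (a hedged sketch; the model name and memory budgets are only examples, not the actual test setup):

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")  # example model only
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# The resulting map depends on how much memory each device advertises and on
# which module classes must not be split; changing either changes the placement.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "300MB", 1: "300MB"},  # illustrative per-device budgets
    no_split_module_classes=model._no_split_modules,
)
print(device_map)  # e.g. {"transformer.wte": 0, ..., "lm_head": 1}
```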

dvrogozh commented:

> I ran the tests with 2 GPUs (CUDA) and only one is failing. So I don't think this is an issue with the modeling code. Could you explore a bit more why this happens only on xpu?

@SunMarc, as I noted in the previous comment, I think the issue is related to how the model was distributed across multiple GPU devices. I had 4 XPU devices when I initially filed the issue, with multiple tests failing. But if I force only 2 GPU devices, then I see the same result as you: only one case fails, for Qwen2-VL, same as on CUDA:

# ZE_AFFINITY_MASK=0,1 <same pytest command as in description>
...
FAILED tests/models/qwen2_vl/test_modeling_qwen2_vl.py::Qwen2VLModelTest::test_model_parallel_beam_search - RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (xpu:0)
===================== 1 failed, 14 passed, 4090 deselected, 1 warning in 9.78s =====================

Can someone try to run this on a system with 4 CUDA devices? Overall, if I understand correctly, device_map distributes the tensors declared in the model across multiple devices based on the available device memory, the distribution algorithm, and hints from the model implementation such as no_split_modules. It is then up to the modeling code to make sure that an operation's tensor operands end up on the same device, but this depends heavily on how the tensors were initially distributed, and there seems to be no automated logic to guarantee that those operands land on the same device. If so, it falls on us to run the models under multiple distribution scenarios and identify the resulting issues, or to implement some automation here, which is likely challenging.
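
As a concrete illustration of this dependence on the device configuration, the placement picked by device_map="auto" can be inspected after loading (a sketch; the checkpoint name is only an example):

```python
import torch
from transformers import AutoModelForCausalLM

# With device_map="auto", Accelerate decides the placement from the currently
# visible devices and their free memory.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",               # example checkpoint only
    device_map="auto",
    torch_dtype=torch.float16,
)

# Module -> device assignment; it differs when run with e.g. ZE_AFFINITY_MASK=0,1
# (2 visible XPU devices) vs. all 4 devices, which is presumably why a different
# set of tests fails in the two setups.
print(model.hf_device_map)
```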


SunMarc commented Jan 22, 2025

Thanks for giving more context and reproducing the results for 2 devices. I'll have a look at the PR then.

> Overall, if I understand correctly, device_map distributes the tensors declared in the model across multiple devices based on the available device memory, the distribution algorithm, and hints from the model implementation such as no_split_modules. [...] it falls on us to run the models under multiple distribution scenarios and identify the resulting issues, or to implement some automation here, which is likely challenging.

That's pretty much it. We put modules that contain a residual connection inside no_split_modules. However, we can't do that for all modules, as they might represent more than 50% of the model's size. Usually we put the embedding layer or the decoder layer there. If there are other residual connections, or big modules that we can't really put inside no_split_modules, we need to do a manual dispatch, as you saw in the other PRs. For example, we can't really put WhisperDecoder inside no_split_modules, so in the forward we have no choice but to move the positions to the same device as inputs_embeds:
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
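A minimal sketch of what such a manual dispatch looks like in context (an illustrative toy module, not the actual Whisper code):

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Illustrative decoder that is too big to go into _no_split_modules."""

    def __init__(self, vocab_size=100, max_positions=128, hidden=32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.embed_positions = nn.Embedding(max_positions, hidden)

    def forward(self, input_ids):
        inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
        position_ids = torch.arange(
            input_ids.shape[1], device=self.embed_positions.weight.device
        )
        positions = self.embed_positions(position_ids)
        # Manual dispatch: once the model is sharded, the two embedding layers
        # can sit on different devices, so align the operands explicitly.
        return inputs_embeds + positions.to(inputs_embeds.device)
```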
Can you check whether the modifications you made are due to that?

dvrogozh added a commit to dvrogozh/transformers that referenced this issue Jan 22, 2025
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Jan 22, 2025
dvrogozh commented:

> Can you check whether the modifications you made are due to that?

@SunMarc : most of the changes in the PR are around positions and inputs_embeds. A few other changes I've made:

  1. Placement of attention_mask, causal_mask, and padding_mask
  2. Placement of hidden_states and residual
  3. Placement of query, key, and value in scaled_dot_product_attention calls (a sketch of this pattern follows the list)
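
For the scaled_dot_product_attention case (item 3), the change is the same device-alignment idea applied to the key/value tensors, as in this sketch (illustrative shapes, not the exact model code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
query = torch.randn(1, 4, 8, 16)
key = torch.randn(1, 4, 8, 16)
value = torch.randn(1, 4, 8, 16)

# After sharding, key/value (e.g. read back from a cache kept on another
# device) may not be on the query's device; align them before the call.
attn_output = F.scaled_dot_product_attention(
    query,
    key.to(query.device),
    value.to(query.device),
)
```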
