multi-gpu: fix inputs_embeds + position_embeds
Fixes the following error in a few models:
```
>       hidden_states = inputs_embeds + pos_embeds
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:3!
```
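
The error appears when a model is sharded across accelerators (for example via `device_map="auto"`), so the token-embedding output and the position-embedding output can end up on different devices. The following is a minimal standalone sketch of the failing pattern and of the fix applied in the diffs below; the device names and tensor sizes are illustrative (the same applies to `cuda:*` or `xpu:*`), and it needs at least two accelerators to run:

```python
import torch

# Illustrative two-device setup; substitute whatever accelerators the host
# exposes (e.g. "xpu:2"/"xpu:3" as in the error above).
dev_a, dev_b = torch.device("cuda:0"), torch.device("cuda:1")

wte = torch.nn.Embedding(50257, 768).to(dev_a)  # token embeddings placed on device A
wpe = torch.nn.Embedding(1024, 768).to(dev_b)   # position embeddings placed on device B

input_ids = torch.arange(8, device=dev_a).unsqueeze(0)
position_ids = torch.arange(8, device=dev_b).unsqueeze(0)

inputs_embeds = wte(input_ids)       # lives on dev_a
position_embeds = wpe(position_ids)  # lives on dev_b

# hidden_states = inputs_embeds + position_embeds  # raises the RuntimeError shown above
hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)  # explicit move, as in the diffs below
```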

Fixes: huggingface#35762
Signed-off-by: Dmitry Rogozhkin <[email protected]>
dvrogozh committed Jan 18, 2025
1 parent 7d4b3dd commit fcedb2f
Showing 5 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
@@ -818,7 +818,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
+        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
 
         # Attention mask.
         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
2 changes: 1 addition & 1 deletion (file path not shown in this capture)
@@ -959,7 +959,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
+        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
 
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
2 changes: 1 addition & 1 deletion src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -775,7 +775,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
+        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
 
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
2 changes: 1 addition & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -882,7 +882,7 @@ def forward(
         if self.project_in is not None:
             inputs_embeds = self.project_in(inputs_embeds)
 
-        hidden_states = inputs_embeds + pos_embeds
+        hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device)
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
4 changes: 3 additions & 1 deletion src/transformers/models/xglm/modeling_xglm.py
@@ -594,7 +594,9 @@ def forward(
                 encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
             )
 
-        hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
+        hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length).to(
+            inputs_embeds.device
+        )
         hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
 
         if self.gradient_checkpointing and self.training:
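
A usage sketch of the path these changes touch: loading one of the affected models sharded across two accelerators (requires `accelerate`). The `max_memory` caps are only illustrative, to force the small model to be split over two devices; on a real multi-device host `device_map="auto"` alone may already shard it:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Illustrative memory caps so accelerate shards this small model across
# two devices instead of keeping it on one.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    device_map="auto",
    max_memory={0: "300MB", 1: "300MB"},
)

inputs = tokenizer("Hello, multi-GPU world", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```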
