multi-gpu: fix tensor device placements for various models
Fixes: huggingface#35762
Signed-off-by: Dmitry Rogozhkin <[email protected]>
dvrogozh committed Jan 18, 2025
1 parent fcedb2f commit 0585ffc
Showing 9 changed files with 17 additions and 11 deletions.
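
Context, not part of the commit: every hunk below applies the same pattern. When a model is sharded across several GPUs (for example with `device_map="auto"`), a tensor produced on one device can feed an op whose other operands sit on a different device, and PyTorch aborts with a device-mismatch `RuntimeError`. Each fix moves the offending operand with `.to(other.device)` immediately before the op. The following is a minimal sketch of the failure mode and the fix; the function and tensor names mirror the aria/mistral hunks, the shapes are illustrative, and it falls back to CPU when fewer than two CUDA devices are visible.

```python
# Minimal sketch (not taken from the commit) of the device-mismatch pattern these changes address.
import torch

def build_padding_mask(causal_mask: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask_length = attention_mask.shape[-1]
    # Without the .to(causal_mask.device), this addition raises a RuntimeError whenever the
    # two tensors live on different devices (as can happen under device_map="auto").
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
        causal_mask.device
    )
    return padding_mask == 0

if __name__ == "__main__":
    multi_gpu = torch.cuda.device_count() >= 2
    dev_a = torch.device("cuda:0") if multi_gpu else torch.device("cpu")
    dev_b = torch.device("cuda:1") if multi_gpu else torch.device("cpu")
    causal_mask = torch.full((1, 1, 4, 4), torch.finfo(torch.float32).min).triu(1).to(dev_a)
    attention_mask = torch.tensor([[1, 1, 1, 0]], device=dev_b)
    print(build_padding_mask(causal_mask, attention_mask).squeeze())
```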
4 changes: 3 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -1113,7 +1113,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
     if attention_mask is not None:
         causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         mask_length = attention_mask.shape[-1]
-        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+            causal_mask.device
+        )
         padding_mask = padding_mask == 0
         causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
             padding_mask, min_dtype
@@ -309,6 +309,7 @@ def slow_forward(
     )  # [batch, intermediate_size, seq_len]
 else:
     conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+    conv_state = conv_state.to(self.conv1d.weight.device)
     hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
     if self.use_conv_bias:
         hidden_states += self.conv1d.bias
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1304,7 +1304,7 @@ def inputs_merger(
     special_image_token_mask = input_ids == self.image_token_id
     new_inputs_embeds = inputs_embeds.clone()
     reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
-    new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
+    new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
     return new_inputs_embeds

 @add_start_docstrings_to_model_forward(
@@ -1606,7 +1606,7 @@ def generate(
     # otherwise we expand manually by concatenating
     if getattr(self.config, "image_token_index", None) is not None:
         special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-        inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
     else:
         logger.warning_once(
             "Expanding inputs for image tokens in InstructBLIP should be done in processing. "
@@ -1641,7 +1641,7 @@ def generate(
     # otherwise we expand manually by concatenating
     if getattr(self.config, "video_token_index", None) is not None:
         special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-        inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
     else:
         logger.warning_once(
             "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
1 change: 1 addition & 0 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -261,6 +261,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
     hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
 else:
     conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+    conv_state = conv_state.to(self.conv1d.weight.device)
     hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
     if self.use_conv_bias:
         hidden_states += self.conv1d.bias
8 changes: 4 additions & 4 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -493,8 +493,8 @@ def forward(
     # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
-        key_states,
-        value_states,
+        key_states.to(query_states.device),
+        value_states.to(query_states.device),
         attn_mask=attention_mask,
         dropout_p=self.dropout if self.training else 0.0,
         is_causal=is_causal,
@@ -570,15 +570,15 @@ def forward(
         output_attentions=output_attentions,
     )
     hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-    hidden_states = residual + hidden_states
+    hidden_states = residual + hidden_states.to(residual.device)

     residual = hidden_states
     hidden_states = self.final_layer_norm(hidden_states)
     hidden_states = self.activation_fn(self.fc1(hidden_states))
     hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
     hidden_states = self.fc2(hidden_states)
     hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-    hidden_states = residual + hidden_states
+    hidden_states = residual + hidden_states.to(residual.device)

     if hidden_states.dtype == torch.float16 and (
         torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
4 changes: 3 additions & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -735,7 +735,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
     if attention_mask.shape[-1] > target_length:
         attention_mask = attention_mask[:, :target_length]
     mask_length = attention_mask.shape[-1]
-    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+        causal_mask.device
+    )
     padding_mask = padding_mask == 0
     causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
         padding_mask, min_dtype
4 changes: 2 additions & 2 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1486,7 +1486,7 @@ def get_rope_index(
     )
     image_index, video_index = 0, 0
     for i, input_ids in enumerate(total_input_ids):
-        input_ids = input_ids[attention_mask[i] == 1]
+        input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
         image_nums, video_nums = 0, 0
         vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
         vision_tokens = input_ids[vision_start_indices + 1]
@@ -1694,7 +1694,7 @@ def forward(
         position_ids = position_ids.view(1, -1).expand(batch_size, -1)
         if cache_position is not None:  # otherwise `deltas` is an int `0`
             delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-        position_ids = position_ids.add(delta)
+        position_ids = position_ids.add(delta.to(position_ids.device))
         position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

     outputs = self.model(
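
For completeness, a hedged sketch of the kind of multi-GPU run in which such mismatches surface; it is not the reproduction from huggingface#35762, and the model id, prompt, and generation settings are illustrative only. With the model split across GPUs by `device_map="auto"`, the inputs and some intermediate tensors (attention masks, cached conv states, image embeddings) can land on different shards, which is what the `.to(...)` calls above guard against.

```python
# Hedged sketch, assuming a machine with several GPUs; not the exact reproduction from
# huggingface#35762. The model id is an illustrative choice (mistral is one of the patched models).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # shard layers across all visible accelerators
)

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```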
