Update min dtype in falcon to prevent bf16 execution issue (#1093)
* update min dtype in falcon to prevent bf16 execution issue

* Update optimum/exporters/openvino/model_patcher.py

* typo

---------

Co-authored-by: Ella Charlaix <[email protected]>
eaidova and echarlaix authored Dec 25, 2024
1 parent 5fa9602 commit 48e72ef
Showing 1 changed file with 35 additions and 8 deletions.
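The change lets callers override the additive-mask fill value instead of hard-coding `torch.finfo(dtype).min`. A minimal sketch of the motivation, under the assumption that the "bf16 execution issue" is the usual overflow: the float32 minimum is not representable in bfloat16 and narrows to -inf, which then turns into NaN in the mask arithmetic, while a float16-range fill value stays finite.

import torch

# Default fill value baked into the exported graph when dtype is float32.
fill_fp32 = torch.finfo(torch.float32).min  # ~ -3.4e38

# If the graph is later executed in bfloat16, the constant is narrowed and overflows,
# and the subsequent mask arithmetic produces NaN.
print(torch.tensor(fill_fp32).to(torch.bfloat16))        # tensor(-inf, dtype=torch.bfloat16)
print(torch.tensor(fill_fp32).to(torch.bfloat16) * 0)    # tensor(nan, dtype=torch.bfloat16)

# A float16-range fill value remains finite after narrowing, which is why
# passing an explicit min_dtype avoids the issue.
fill_fp16 = torch.finfo(torch.float16).min  # -65504.0
print(torch.tensor(fill_fp16).to(torch.bfloat16).isfinite())  # tensor(True)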
43 changes: 35 additions & 8 deletions optimum/exporters/openvino/model_patcher.py
@@ -2501,6 +2501,40 @@ def __enter__(self):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)


# Adapted from https://github.com/huggingface/transformers/blob/31f9a289a6207be6cae746e009d8e0db523be203/src/transformers/models/falcon/modeling_falcon.py#L1138
def _falcon_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        # different from original: allow to provide min_dtype as parameter
        min_dtype = torch.finfo(dtype).min if "min_dtype" not in kwargs else kwargs["min_dtype"]
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask


def _falcon_update_causal_mask(
    self,
    attention_mask: torch.Tensor,
@@ -2520,13 +2554,6 @@ def _falcon_update_causal_mask(
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
        _prepare_4d_causal_attention_mask_with_cache_position = (
            self._prepare_4d_causal_attention_mask_with_cache_position
        )
    else:
        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position

    if self.config._attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
@@ -2568,7 +2595,7 @@ def _falcon_update_causal_mask(
    )

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
    causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
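For reference, a self-contained toy invocation of the new helper showing the `min_dtype` override. Shapes and values are illustrative only and not part of the patch; the remainder of the hunk, which wires the helper into `_falcon_update_causal_mask`, is collapsed in this view, and the helper defined in the diff above is assumed to be in scope.

import torch

batch_size, sequence_length, target_length = 1, 4, 8
attention_mask = torch.ones(batch_size, target_length)  # 2D padding mask
cache_position = torch.arange(sequence_length)

mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=sequence_length,
    target_length=target_length,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=cache_position,
    batch_size=batch_size,
    min_dtype=torch.finfo(torch.float16).min,  # override the default torch.finfo(torch.float32).min
)
print(mask.shape)  # torch.Size([1, 1, 4, 8])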
