Commit
fix Phi3
IlyasMoutawwakil committed Oct 24, 2024
1 parent 29a5dbc commit c860ba9
Showing 2 changed files with 97 additions and 2 deletions.
10 changes: 8 additions & 2 deletions optimum/exporters/onnx/model_configs.py
@@ -76,6 +76,7 @@
FalconModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Phi3ModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
@@ -304,6 +305,11 @@ class Phi3OnnxConfig(PhiOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
# This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
@@ -2156,8 +2162,8 @@ class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast):
DummySeq2SeqPastKeyValuesGenerator,
DummyPix2StructInputGenerator,
)
- # Min operator needs to support int64, which is the case for opset>=12
- DEFAULT_ONNX_OPSET = 12
+ DEFAULT_ONNX_OPSET = 14  # now uses 'aten::triu', which requires opset 14

@property
def inputs(self):
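For context, here is a minimal sketch (not part of the commit) of how this export path is typically exercised through optimum's Python entry point; the checkpoint id, output directory, and task string are illustrative assumptions.

from optimum.exporters.onnx import main_export

# Exporting a Phi-3 checkpoint goes through Phi3OnnxConfig.patch_model_for_export above,
# which now wraps the model in the Phi3ModelPatcher added in model_patcher.py below.
main_export(
    model_name_or_path="microsoft/Phi-3-mini-4k-instruct",  # illustrative checkpoint
    output="phi3_onnx",
    task="text-generation-with-past",
)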
89 changes: 89 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -1155,3 +1155,92 @@ def __exit__(self, exc_type, exc_value, traceback):
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

CLIPSdpaAttention.forward = self.original_sdpa_forward


# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3
def _prepare_4d_causal_attention_mask_with_cache_position_patched(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Any,
past_key_values: Any,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, returns it as is.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the 0 padding (the part of the cache that is not yet filled).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
config (`Phi3Config`):
The model's configuration class.
past_key_values (`Cache`):
The cache class currently being used for generation.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify if the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
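To make the mask construction above concrete, here is a tiny standalone illustration (not part of the commit) of the diagonal part of the logic, for a single decoded token attending into a cache of five slots:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 1, 5
cache_position = torch.tensor([3])  # the new token sits at position 3

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask *= diagonal_attend_mask
# positions 0..3 are multiplied by False and become 0.0 (attendable);
# position 4 is multiplied by True and stays at min_dtype (blocked)

In the full function, the sliding-window branch then ORs an additional mask over positions older than config.sliding_window into diagonal_attend_mask, and the 2D attention_mask is folded in afterwards as padding.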


from transformers import Phi3ForCausalLM


class Phi3ModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Phi3ForCausalLM,
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

if _transformers_version >= version.parse("4.46.0"):
if hasattr(self._model, "model"):
self._model.model._prepare_4d_causal_attention_mask_with_cache_position = (
_prepare_4d_causal_attention_mask_with_cache_position_patched
)
else:
self._model._prepare_4d_causal_attention_mask_with_cache_position = (
_prepare_4d_causal_attention_mask_with_cache_position_patched
)
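For orientation, a rough sketch of how the exporter consumes such a patcher (an assumption about the surrounding optimum flow, not something shown in this diff); the construction of onnx_config, model, and dummy_inputs is elided and the names are illustrative.

import torch

# onnx_config: an instantiated Phi3OnnxConfig; model: a loaded Phi3ForCausalLM;
# dummy_inputs: a dict of example tensors produced by the config's dummy input generators
patcher = onnx_config.patch_model_for_export(model, model_kwargs=None)
with patcher:
    # the attribute swap performed in __init__ above is already in place at this point,
    # so the traced graph uses the patched, export-friendly mask function
    torch.onnx.export(model, (dummy_inputs,), "phi3_decoder.onnx", opset_version=14)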
