diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 5d7c29b4f45..ed270660bbd 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -76,6 +76,7 @@
     FalconModelPatcher,
     MistralModelPatcher,
     MusicgenModelPatcher,
+    Phi3ModelPatcher,
     SAMModelPatcher,
     SentenceTransformersCLIPPatcher,
     SentenceTransformersTransformerPatcher,
@@ -304,6 +305,11 @@ class Phi3OnnxConfig(PhiOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
     MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
@@ -2156,8 +2162,8 @@ class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast):
         DummySeq2SeqPastKeyValuesGenerator,
         DummyPix2StructInputGenerator,
     )
-    # Min operator needs to support int64, which is the case for opset>=12
-    DEFAULT_ONNX_OPSET = 12
+
+    DEFAULT_ONNX_OPSET = 14  # the mask preparation now uses 'aten::triu', which requires opset >= 14
 
     @property
     def inputs(self):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 34ed5fcae46..cad4a418808 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -1155,3 +1155,92 @@ def __exit__(self, exc_type, exc_value, traceback):
             from transformers.models.clip.modeling_clip import CLIPSdpaAttention
 
             CLIPSdpaAttention.forward = self.original_sdpa_forward
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3
+def _prepare_4d_causal_attention_mask_with_cache_position_patched(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    cache_position: torch.Tensor,
+    batch_size: int,
+    config: Any,
+    past_key_values: Any,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, does nothing.
+
+    Args:
+        attention_mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+            `(batch_size, 1, query_length, key_value_length)`.
+        sequence_length (`int`):
+            The sequence length being processed.
+        target_length (`int`):
+            The target length: when generating with a static cache, the mask should be as long as the static cache,
+            to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+        dtype (`torch.dtype`):
+            The dtype to use for the 4D attention mask.
+        device (`torch.device`):
+            The device to place the 4D attention mask on.
+        cache_position (`torch.Tensor`):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`int`):
+            Batch size.
+        config (`Phi3Config`):
+            The model's configuration class.
+        past_key_values (`Cache`):
+            The cache class that is currently being used to generate.
+    """
+    # Local import: SlidingWindowCache may not exist in the older transformers versions optimum still supports.
+    from transformers.cache_utils import SlidingWindowCache
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        min_dtype = torch.finfo(dtype).min
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        if config.sliding_window is not None:
+            # If we have a sliding window, we should not attend to tokens beyond the sliding window length, so we
+            # mask them out as well. The check verifies whether the current checkpoint was trained with a sliding
+            # window or not.
+            if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    cache_position.reshape(-1, 1) - config.sliding_window
+                )
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+        causal_mask *= diagonal_attend_mask
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            if attention_mask.shape[-1] > target_length:
+                attention_mask = attention_mask[:, :target_length]
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    return causal_mask
+
+
+class Phi3ModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        # In transformers >= 4.46 the causal mask preparation lives on the model class; replace it with the
+        # patched implementation above so the exported graph is built with it.
+        if _transformers_version >= version.parse("4.46.0"):
+            if hasattr(self._model, "model"):
+                self._model.model._prepare_4d_causal_attention_mask_with_cache_position = (
+                    _prepare_4d_causal_attention_mask_with_cache_position_patched
+                )
+            else:
+                self._model._prepare_4d_causal_attention_mask_with_cache_position = (
+                    _prepare_4d_causal_attention_mask_with_cache_position_patched
+                )
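
For reference, the mask construction in `_prepare_4d_causal_attention_mask_with_cache_position_patched` can be exercised in isolation. The sketch below re-implements only the causal-plus-sliding-window branch on toy inputs; the batch size, lengths, and sliding-window value are illustrative and not taken from the diff:

import torch

batch_size, sequence_length, target_length = 1, 3, 5
sliding_window = 2  # illustrative; real Phi-3 checkpoints set this in Phi3Config
dtype, device = torch.float32, torch.device("cpu")
min_dtype = torch.finfo(dtype).min

# Positions of the current query tokens in the full sequence (a prefill step here).
cache_position = torch.arange(sequence_length, device=device)

# Start fully masked, then keep masking only where attention is disallowed.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
sliding_attend_mask = torch.arange(target_length, device=device) <= (
    cache_position.reshape(-1, 1) - sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print(causal_mask[0, 0])
# Row i keeps 0.0 (attendable) only for key positions in (i - sliding_window, i],
# and min_dtype (masked) everywhere else: causal plus sliding-window masking.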
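
The Pix2Struct opset bump can be checked independently of this diff: `torch.triu` lowers to the ONNX `Trilu` operator, which only exists from opset 14, so exporting it at opset 12 fails with an unsupported-operator error. A minimal repro, with an arbitrary module and tensor shape:

import io

import torch


class TriuModule(torch.nn.Module):
    def forward(self, x):
        # torch.triu maps to ONNX 'Trilu', introduced in opset 14
        return torch.triu(x, diagonal=1)


x = torch.ones(4, 4)

# Succeeds: Trilu is available at opset >= 14.
torch.onnx.export(TriuModule(), (x,), io.BytesIO(), opset_version=14)

# Raises an unsupported-operator error: no symbolic for 'triu' below opset 14.
# torch.onnx.export(TriuModule(), (x,), io.BytesIO(), opset_version=12)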
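
End to end, the new patcher is picked up automatically: `Phi3OnnxConfig.patch_model_for_export` wraps the model in `Phi3ModelPatcher` during export. A usage sketch, assuming a public Phi-3 checkpoint; the checkpoint id below is an example, not part of the diff:

from optimum.exporters.onnx import main_export

# Exporting a Phi-3 checkpoint routes through Phi3OnnxConfig, which now returns
# Phi3ModelPatcher from patch_model_for_export.
main_export(
    "microsoft/Phi-3-mini-4k-instruct",  # example checkpoint id (assumption)
    output="phi3_onnx",
    task="text-generation-with-past",
)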