diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 215d65549f..f3a3ad78db 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -771,7 +771,8 @@ def __init__( ): super().__init__(config, model, model_kwargs) - self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask + if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask def patched_forward(input_ids, attention_mask): result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})