From 067587ce1bc2b039a069538a76200aa524d04e95 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 31 Oct 2024 17:08:40 +0100 Subject: [PATCH] fix mpt --- optimum/exporters/onnx/model_configs.py | 1 + tests/onnxruntime/test_modeling.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 82486524ce0..10d38fae76d 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -343,6 +343,7 @@ def patch_model_for_export( class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. DEFAULT_ONNX_OPSET = 13 + MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 9187b851fc0..e90bb4be758 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2325,7 +2325,6 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "gptj", "llama", "mistral", - "mpt", "opt", ] @@ -2335,8 +2334,9 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): if check_if_transformers_greater("4.38"): SUPPORTED_ARCHITECTURES.append("gemma") + # TODO: fix "mpt" for which inference fails for transformers < v4.41 if check_if_transformers_greater("4.41"): - SUPPORTED_ARCHITECTURES.append("phi3") + SUPPORTED_ARCHITECTURES.extend(["phi3", "mpt"]) FULL_GRID = { "model_arch": SUPPORTED_ARCHITECTURES, @@ -2449,7 +2449,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach transformers_model = AutoModelForCausalLM.from_pretrained(model_id) transformers_model = transformers_model.eval() tokenizer = get_preprocessor(model_id) - tokens = tokenizer("This is a sample output", return_tensors="pt") + tokens = tokenizer("This is a sample input", return_tensors="pt") position_ids = None if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: input_shape = tokens["input_ids"].shape @@ -2471,7 +2471,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach # Compare batched generation. tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" - tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens = tokenizer(["This is", "This is a sample input"], return_tensors="pt", padding=True) onnx_model.generation_config.eos_token_id = None transformers_model.generation_config.eos_token_id = None onnx_model.config.eos_token_id = None