From 8a0c11dd3282b0dcf50e5b18e4c79e0debd564b0 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Tue, 3 Oct 2023 09:49:56 +0200 Subject: [PATCH] Fix normalized config key for models architecture (#1408) --- optimum/onnxruntime/utils.py | 15 ++++++++------- optimum/utils/normalized_config.py | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 4152febb2d7..9f5178c46c0 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -101,8 +101,8 @@ class ORTConfigManager: "albert": "bert", "bart": "bart", "bert": "bert", - "big_bird": "bert", - # "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py` + "big-bird": "bert", + # "bigbird-pegasus": None, # bug in `fusion_skiplayernorm.py` "blenderbot": "bert", "bloom": "gpt2", "camembert": "bert", @@ -112,9 +112,9 @@ class ORTConfigManager: "distilbert": "bert", "electra": "bert", "gpt2": "gpt2", - "gpt_bigcode": "gpt2", - "gpt_neo": "gpt2", - "gpt_neox": "gpt2", + "gpt-bigcode": "gpt2", + "gpt-neo": "gpt2", + "gpt-neox": "gpt2", "gptj": "gpt2", # longt5 with O4 results in segmentation fault "longt5": "bert", @@ -122,7 +122,7 @@ class ORTConfigManager: "marian": "bart", "mbart": "bart", "mt5": "bart", - "m2m_100": "bart", + "m2m-100": "bart", "nystromformer": "bert", "pegasus": "bert", "roberta": "bert", @@ -134,6 +134,7 @@ class ORTConfigManager: @classmethod def get_model_ort_type(cls, model_type: str) -> str: + model_type = model_type.replace("_", "-") cls.check_supported_model(model_type) return cls._conf[model_type] @@ -161,7 +162,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config "vit", "swin", ] - + model_type = model_type.replace("_", "-") if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization): raise NotImplementedError( f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {list(cls._conf.keys())} are supported. " diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 30bbec030ab..14dd17488c4 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -209,7 +209,7 @@ class NormalizedConfigManager: # "big_bird": NormalizedTextConfig, # "bigbird_pegasus": BartLikeNormalizedTextConfig, "blenderbot": BartLikeNormalizedTextConfig, - "blenderbot_small": BartLikeNormalizedTextConfig, + "blenderbot-small": BartLikeNormalizedTextConfig, "bloom": NormalizedTextConfig.with_args(num_layers="n_layer"), "falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"), "camembert": NormalizedTextConfig, @@ -223,9 +223,9 @@ class NormalizedConfigManager: "electra": NormalizedTextConfig, "encoder-decoder": NormalizedEncoderDecoderConfig, "gpt2": GPT2LikeNormalizedTextConfig, - "gpt-bigcode": GPT2LikeNormalizedTextConfig, - "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"), - "gpt_neox": NormalizedTextConfig, + "gpt-bigcode": GPTBigCodeNormalizedTextConfig, + "gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"), + "gpt-neox": NormalizedTextConfig, "llama": NormalizedTextConfig, "gptj": GPT2LikeNormalizedTextConfig, "imagegpt": GPT2LikeNormalizedTextConfig, @@ -233,7 +233,7 @@ class NormalizedConfigManager: "marian": BartLikeNormalizedTextConfig, "mbart": BartLikeNormalizedTextConfig, "mt5": T5LikeNormalizedTextConfig, - "m2m_100": BartLikeNormalizedTextConfig, + "m2m-100": BartLikeNormalizedTextConfig, "nystromformer": NormalizedTextConfig, "opt": NormalizedTextConfig, "pegasus": BartLikeNormalizedTextConfig, @@ -242,7 +242,7 @@ class NormalizedConfigManager: "regnet": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, - "speech_to_text": SpeechToTextLikeNormalizedTextConfig, + "speech-to-text": SpeechToTextLikeNormalizedTextConfig, "splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, "trocr": TrOCRLikeNormalizedTextConfig, @@ -252,7 +252,6 @@ class NormalizedConfigManager: "xlm-roberta": NormalizedTextConfig, "yolos": NormalizedVisionConfig, "mpt": MPTNormalizedTextConfig, - "gpt_bigcode": GPTBigCodeNormalizedTextConfig, } @classmethod @@ -266,5 +265,6 @@ def check_supported_model(cls, model_type: str): @classmethod def get_normalized_config_class(cls, model_type: str) -> Type: + model_type = model_type.replace("_", "-") cls.check_supported_model(model_type) return cls._conf[model_type]