Skip to content

Commit

Permalink
Fix normalized config key for models architecture (#1408)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Oct 3, 2023
1 parent ca7b9d8 commit 8a0c11d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
15 changes: 8 additions & 7 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class ORTConfigManager:
"albert": "bert",
"bart": "bart",
"bert": "bert",
"big_bird": "bert",
# "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py`
"big-bird": "bert",
# "bigbird-pegasus": None, # bug in `fusion_skiplayernorm.py`
"blenderbot": "bert",
"bloom": "gpt2",
"camembert": "bert",
Expand All @@ -112,17 +112,17 @@ class ORTConfigManager:
"distilbert": "bert",
"electra": "bert",
"gpt2": "gpt2",
"gpt_bigcode": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gpt-bigcode": "gpt2",
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"m2m-100": "bart",
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
Expand All @@ -134,6 +134,7 @@ class ORTConfigManager:

@classmethod
def get_model_ort_type(cls, model_type: str) -> str:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

Expand Down Expand Up @@ -161,7 +162,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"vit",
"swin",
]

model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
raise NotImplementedError(
f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {list(cls._conf.keys())} are supported. "
Expand Down
14 changes: 7 additions & 7 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class NormalizedConfigManager:
# "big_bird": NormalizedTextConfig,
# "bigbird_pegasus": BartLikeNormalizedTextConfig,
"blenderbot": BartLikeNormalizedTextConfig,
"blenderbot_small": BartLikeNormalizedTextConfig,
"blenderbot-small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
"camembert": NormalizedTextConfig,
Expand All @@ -223,17 +223,17 @@ class NormalizedConfigManager:
"electra": NormalizedTextConfig,
"encoder-decoder": NormalizedEncoderDecoderConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"gpt-bigcode": GPTBigCodeNormalizedTextConfig,
"gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt-neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"opt": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
Expand All @@ -242,7 +242,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"speech_to_text": SpeechToTextLikeNormalizedTextConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
"trocr": TrOCRLikeNormalizedTextConfig,
Expand All @@ -252,7 +252,6 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"mpt": MPTNormalizedTextConfig,
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
}

@classmethod
Expand All @@ -266,5 +265,6 @@ def check_supported_model(cls, model_type: str):

@classmethod
def get_normalized_config_class(cls, model_type: str) -> Type:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

0 comments on commit 8a0c11d

Please sign in to comment.