Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix normalized config key for models architecture #1408

Merged
merged 3 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class ORTConfigManager:
"albert": "bert",
"bart": "bart",
"bert": "bert",
"big_bird": "bert",
# "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py`
"big-bird": "bert",
# "bigbird-pegasus": None, # bug in `fusion_skiplayernorm.py`
"blenderbot": "bert",
"bloom": "gpt2",
"camembert": "bert",
Expand All @@ -112,17 +112,17 @@ class ORTConfigManager:
"distilbert": "bert",
"electra": "bert",
"gpt2": "gpt2",
"gpt_bigcode": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gpt-bigcode": "gpt2",
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"m2m-100": "bart",
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
Expand All @@ -134,6 +134,7 @@ class ORTConfigManager:

@classmethod
def get_model_ort_type(cls, model_type: str) -> str:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

Expand Down Expand Up @@ -161,7 +162,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"vit",
"swin",
]

model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
raise NotImplementedError(
f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {list(cls._conf.keys())} are supported. "
Expand Down
14 changes: 7 additions & 7 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class NormalizedConfigManager:
# "big_bird": NormalizedTextConfig,
# "bigbird_pegasus": BartLikeNormalizedTextConfig,
"blenderbot": BartLikeNormalizedTextConfig,
"blenderbot_small": BartLikeNormalizedTextConfig,
"blenderbot-small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
"camembert": NormalizedTextConfig,
Expand All @@ -223,17 +223,17 @@ class NormalizedConfigManager:
"electra": NormalizedTextConfig,
"encoder-decoder": NormalizedEncoderDecoderConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"gpt-bigcode": GPTBigCodeNormalizedTextConfig,
"gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt-neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"opt": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
Expand All @@ -242,7 +242,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"speech_to_text": SpeechToTextLikeNormalizedTextConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
"trocr": TrOCRLikeNormalizedTextConfig,
Expand All @@ -252,7 +252,6 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"mpt": MPTNormalizedTextConfig,
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
}

@classmethod
Expand All @@ -266,5 +265,6 @@ def check_supported_model(cls, model_type: str):

@classmethod
def get_normalized_config_class(cls, model_type: str) -> Type:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]