Enable ONNX export of decoder-only models without the need to merge (#1257)
* refactor ONNX export of decoder models

* fix IO bindings

* format

* enable mpt support

* format

* add trust remote code

* fix test

* format

* fix quantization

* add test

* format

* fix optimization

* fix compatibility with legacy models

* fix style

* add export to main_export

* add legacy to ONNX export


* patch model to fix causal lm generation

* add no post process

* remove bloom caching

* fix dynamic axis for position ids

* fix external data

* add model patcher

* format

* fix bart model patcher

* fix model patcher for opt models

* fix format

* add test

* format

* fix ort docker

* add test

* fix bart model patcher

* raise when unsupported model

* add cached file

* add position warning

* fixes

* enable post process after export to remove tied weights

* comment

* remove test

* fix test

* modify model

* remove deprecated use_merged in test

* Add mistral model patcher

* add slow test

* add workflow
echarlaix authored Oct 16, 2023
1 parent c5ad7f9 commit 6e15777
Showing 22 changed files with 999 additions and 714 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/test_onnxruntime_slow.yml
@@ -0,0 +1,33 @@
name: ONNX Runtime slow / Python - Test

on:
  workflow_dispatch:
  schedule:
    - cron: 0 7 * * * # every day at 7am

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.8, 3.9]
        os: [ubuntu-20.04]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies for export
        run: |
          pip install .[tests,onnxruntime]
      - name: Test with unittest
        working-directory: tests
        run: |
          RUN_SLOW=1 pytest onnxruntime -s -m "run_slow" --durations=0
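For reference, the -m "run_slow" filter selects tests tagged with a run_slow marker that is gated on the RUN_SLOW=1 environment variable. A minimal sketch of such a test, assuming this marker convention (the names below are illustrative, not taken from the Optimum test suite):

import os

import pytest

@pytest.mark.run_slow
@pytest.mark.skipif(os.environ.get("RUN_SLOW", "0") != "1", reason="only run when RUN_SLOW=1")
def test_decoder_onnx_export_slow():
    # A long-running export + inference round-trip would go here.
    ...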
18 changes: 9 additions & 9 deletions optimum/commands/export/onnx.py
@@ -136,14 +136,6 @@ def parse_args_onnx(parser):
        default=None,
        help=("The library of the model." " If not provided, will attempt to infer the local checkpoint's library"),
    )
    optional_group.add_argument(
        "--no-position-ids",
        action="store_true",
        help=(
            "Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
        ),
    )

    input_group = parser.add_argument_group(
        "Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
    )
@@ -217,6 +209,14 @@ def parse_args_onnx(parser):
        default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
        help="For Segment Anything. It corresponds to the number of points per segmentation mask.",
    )
    optional_group.add_argument(
        "--legacy",
        action="store_true",
        help=(
            "Export decoder-only models in three files (without + with past and the resulting merged model). "
            "Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
        ),
    )

    # deprecated argument
    parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
@@ -255,6 +255,6 @@ def run(self):
            use_subprocess=True,
            _variant=self.args.variant,
            library_name=self.args.library_name,
            no_position_ids=self.args.no_position_ids,
            legacy=self.args.legacy,
            **input_shapes,
        )
21 changes: 11 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -68,7 +68,7 @@ def _get_submodels_and_onnx_configs(
    float_dtype: str = "fp32",
    fn_get_submodels: Optional[Callable] = None,
    preprocessors: Optional[List[Any]] = None,
    no_position_ids: bool = False,
    legacy: bool = False,
):
    is_stable_diffusion = "stable-diffusion" in task
    if not custom_architecture:
@@ -82,8 +82,8 @@
            model=model, exporter="onnx", task=task
        )
        onnx_config_kwargs = {}
        if task.startswith("text-generation") and no_position_ids:
            onnx_config_kwargs["no_position_ids"] = no_position_ids
        if task.startswith("text-generation") and legacy:
            onnx_config_kwargs["no_position_ids"] = legacy

        onnx_config = onnx_config_constructor(
            model.config,
@@ -106,7 +106,7 @@
        ):
            models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
        elif task.startswith("text-generation") and not monolith:
            models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
            models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy)
        elif model.config.model_type == "sam":
            models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
        else:
@@ -184,7 +184,7 @@ def main_export(
    use_subprocess: bool = False,
    _variant: str = "default",
    library_name: Optional[str] = None,
    no_position_ids: bool = False,
    legacy: bool = False,
    **kwargs_shapes,
):
    """
@@ -264,8 +264,8 @@
        library_name (`Optional[str]`, defaults to `None`):
            The library of the model (`"transformers"`, `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
            the library name for the checkpoint.
        no_position_ids (`bool`, defaults to `False`):
            Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
        legacy (`bool`, defaults to `False`):
            Disable the use of position_ids for text-generation models that require it for batched generation. Also enables exporting decoder-only models as three files (without past, with past, and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
        **kwargs_shapes (`Dict`):
            Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
@@ -353,9 +353,9 @@
    is_stable_diffusion = "stable-diffusion" in task
    model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

    if no_position_ids and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
    if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
        logger.warning(
            f"no_position_ids=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `no_position_ids=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
            f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
        )

    if not is_stable_diffusion:
@@ -424,7 +424,7 @@
        fn_get_submodels=fn_get_submodels,
        preprocessors=preprocessors,
        _variant=_variant,
        no_position_ids=no_position_ids,
        legacy=legacy,
    )

    if not is_stable_diffusion:
@@ -610,6 +610,7 @@ def main():
        pad_token_id=args.pad_token_id,
        for_ort=args.for_ort,
        library_name=args.library_name,
        legacy=args.legacy,
        **input_shapes,
    )

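As an end-to-end illustration of the new flag, here is a hypothetical sketch using the main_export Python API changed above (the model name, output directories, and task are example values; the CLI equivalent is optimum-cli export onnx, with or without --legacy):

from optimum.exporters.onnx import main_export

# New default behavior: a single decoder ONNX model serves both the first
# forward pass (no cache) and subsequent passes (with KV cache), so no
# separate merge step is required.
main_export("gpt2", output="gpt2_onnx", task="text-generation-with-past")

# Backward-compatible behavior: decoder without past + decoder with past,
# plus the resulting merged model, as described in the --legacy help text.
main_export("gpt2", output="gpt2_onnx_legacy", task="text-generation-with-past", legacy=True)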
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -585,7 +585,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
        elif self.task == "feature-extraction":
            common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
        else:
            common_outputs = OrderedDict({"logits": {0: "batch_size"}})
            common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
        if self.use_past:
            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
            self.add_past_key_values(common_outputs, direction="outputs")
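These axis dictionaries ultimately become the dynamic_axes argument of torch.onnx.export, so marking axis 1 of logits as sequence_length is what lets one exported decoder handle both multi-token prefill and single-token decoding. A standalone sketch of the idea, assuming a plain GPT-2 checkpoint (this is not the Optimum export code itself):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", use_cache=False)
model.eval()
input_ids = torch.randint(0, model.config.vocab_size, (2, 8))

torch.onnx.export(
    model,
    (input_ids,),
    "decoder_model.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    # Both the batch and sequence axes are left dynamic, mirroring the change above.
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=13,
)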
7 changes: 2 additions & 5 deletions optimum/exporters/onnx/config.py
@@ -92,7 +92,7 @@ def __init__(
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.use_past_in_inputs:
            common_inputs = {"input_ids": {0: "batch_size"}}
            common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
            self.add_past_key_values(common_inputs, direction="inputs")
            common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
        else:
@@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
        # generating wrong position_ids in the model itself:
        # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
        if not self.no_position_ids and self.task == "text-generation":
            if self.use_past_in_inputs:
                common_inputs["position_ids"] = {0: "batch_size"}
            else:
                common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

        return common_inputs

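The dynamic axis matters because position_ids are usually derived from the attention mask during batched generation, and with the single-decoder export the same graph must accept a full prompt as well as one token at a time. A small illustrative sketch of that derivation, mirroring the usual transformers prepare_inputs_for_generation pattern (not code from this PR):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1],
                               [1, 1, 1, 1]])  # left-padded batch of 2
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# Prefill feeds every position; decoding with past feeds only the last one.
next_step_position_ids = position_ids[:, -1:]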
46 changes: 41 additions & 5 deletions optimum/exporters/onnx/model_configs.py
@@ -56,7 +56,15 @@
    TextSeq2SeqOnnxConfig,
    VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, WavLMModelPatcher
from .model_patcher import (
    BartModelPatcher,
    BloomModelPatcher,
    LlamaModelPatcher,
    MistralModelPatcher,
    OPTModelPatcher,
    SAMModelPatcher,
    WavLMModelPatcher,
)


if TYPE_CHECKING:
@@ -216,13 +224,23 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return OPTModelPatcher(self, model, model_kwargs=model_kwargs)

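The patch_model_for_export overrides added throughout this file return ModelPatcher subclasses from optimum/exporters/onnx/model_patcher.py. A rough, assumed sketch of the pattern they follow (the real patchers apply model-specific rewrites, e.g. to make causal-mask code paths traceable; SketchModelPatcher below is illustrative only):

import functools

class SketchModelPatcher:
    """Context manager that swaps in a patched forward for the export."""

    def __init__(self, config, model, model_kwargs=None):
        self._model = model
        self.orig_forward = model.forward

    def __enter__(self):
        @functools.wraps(self.orig_forward)
        def patched_forward(*args, **kwargs):
            # Model-specific fixes for ONNX tracing would go here.
            return self.orig_forward(*args, **kwargs)

        self._model.forward = patched_forward
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward once tracing is done.
        self._model.forward = self.orig_forward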

class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
@@ -233,6 +251,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


class MPTOnnxConfig(TextDecoderOnnxConfig):
    # MPT does not require position_ids input.
@@ -241,6 +264,11 @@ class MPTOnnxConfig(TextDecoderOnnxConfig):
        num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
    )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return BloomModelPatcher(self, model, model_kwargs=model_kwargs)


class BloomOnnxConfig(TextDecoderOnnxConfig):
    # Bloom does not require position_ids input.
@@ -274,6 +302,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
                1: decoder_sequence_name,
            }

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return BloomModelPatcher(self, model, model_kwargs=model_kwargs)


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -413,7 +446,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
        return int_tensor


class BartOnnxConfig(TextSeq2SeqOnnxConfig):
class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        encoder_num_layers="encoder_layers",
        decoder_num_layers="decoder_layers",
@@ -537,11 +570,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
    )


class MBartOnnxConfig(BartOnnxConfig):
    pass
class BartOnnxConfig(M2M100OnnxConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return BartModelPatcher(self, model, model_kwargs=model_kwargs)


class M2M100OnnxConfig(BartOnnxConfig):
class MBartOnnxConfig(BartOnnxConfig):
    pass

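For context, these patchers are meant to be consumed as context managers around the actual export call, roughly like the hypothetical helper below (export_with_patch is illustrative; the real logic lives in optimum/exporters/onnx/convert.py):

import torch

def export_with_patch(model, onnx_config, dummy_inputs, output_path, model_kwargs=None):
    # The patched forward is only live while the export is being traced.
    with onnx_config.patch_model_for_export(model, model_kwargs=model_kwargs):
        torch.onnx.export(
            model,
            (dummy_inputs,),
            output_path,
            opset_version=onnx_config.DEFAULT_ONNX_OPSET,
        )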

