From 154f7eba05186ad3bb1b2e4c98cf0fbdf202dc53 Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Tue, 5 Nov 2024 16:35:48 -0800 Subject: [PATCH 01/30] Export to ExecuTorch: Initial Integration --- .github/workflows/test_executorch_export.yml | 35 ++ .github/workflows/test_executorch_runtime.yml | 35 ++ docs/source/exporters/executorch/overview.mdx | 26 + .../package_reference/configuration.mdx | 54 ++ .../executorch/package_reference/export.mdx | 26 + .../executorch/usage_guides/contribute.mdx | 57 +++ .../usage_guides/export_a_model.mdx | 124 +++++ docs/source/exporters/overview.mdx | 2 +- optimum/commands/__init__.py | 2 +- optimum/commands/export/__init__.py | 1 + optimum/commands/export/base.py | 6 + optimum/commands/export/executorch.py | 67 +++ optimum/executorchruntime/__init__.py | 29 ++ .../executorchruntime/modeling_executorch.py | 464 ++++++++++++++++++ optimum/exporters/executorch/__init__.py | 44 ++ optimum/exporters/executorch/__main__.py | 160 ++++++ optimum/exporters/executorch/convert.py | 90 ++++ .../exporters/executorch/recipe_registry.py | 68 +++ .../exporters/executorch/recipes/__init__.py | 11 + .../exporters/executorch/recipes/xnnpack.py | 97 ++++ optimum/exporters/executorch/task_registry.py | 68 +++ .../exporters/executorch/tasks/__init__.py | 11 + .../exporters/executorch/tasks/causal_lm.py | 66 +++ optimum/onnxruntime/runs/__init__.py | 6 +- setup.py | 4 + tests/executorch/export/__init__.py | 14 + .../export/test_exporters_executorch.py | 115 +++++ tests/executorch/runtime/__init__.py | 14 + tests/executorch/runtime/test_modeling.py | 207 ++++++++ 29 files changed, 1898 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/test_executorch_export.yml create mode 100644 .github/workflows/test_executorch_runtime.yml create mode 100644 docs/source/exporters/executorch/overview.mdx create mode 100644 docs/source/exporters/executorch/package_reference/configuration.mdx create mode 100644 docs/source/exporters/executorch/package_reference/export.mdx create mode 100644 docs/source/exporters/executorch/usage_guides/contribute.mdx create mode 100644 docs/source/exporters/executorch/usage_guides/export_a_model.mdx create mode 100644 optimum/commands/export/executorch.py create mode 100644 optimum/executorchruntime/__init__.py create mode 100644 optimum/executorchruntime/modeling_executorch.py create mode 100644 optimum/exporters/executorch/__init__.py create mode 100644 optimum/exporters/executorch/__main__.py create mode 100644 optimum/exporters/executorch/convert.py create mode 100644 optimum/exporters/executorch/recipe_registry.py create mode 100644 optimum/exporters/executorch/recipes/__init__.py create mode 100644 optimum/exporters/executorch/recipes/xnnpack.py create mode 100644 optimum/exporters/executorch/task_registry.py create mode 100644 optimum/exporters/executorch/tasks/__init__.py create mode 100644 optimum/exporters/executorch/tasks/causal_lm.py create mode 100644 tests/executorch/export/__init__.py create mode 100644 tests/executorch/export/test_exporters_executorch.py create mode 100644 tests/executorch/runtime/__init__.py create mode 100644 tests/executorch/runtime/test_modeling.py diff --git a/.github/workflows/test_executorch_export.yml b/.github/workflows/test_executorch_export.yml new file mode 100644 index 00000000000..eb8f995f71c --- /dev/null +++ b/.github/workflows/test_executorch_export.yml @@ -0,0 +1,35 @@ +name: ExecuTorch Export / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [ubuntu-20.04, macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests,exporters-executorch] + pip list + - name: Run tests + working-directory: tests + run: | + RUN_SLOW=1 pytest executorch/export/test_*.py -s -vvvv --durations=0 diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_runtime.yml new file mode 100644 index 00000000000..f7e3abcceff --- /dev/null +++ b/.github/workflows/test_executorch_runtime.yml @@ -0,0 +1,35 @@ +name: ExecuTorch Runtime / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [ubuntu-20.04, macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests,exporters-executorch] + pip list + - name: Run tests + working-directory: tests + run: | + RUN_SLOW=1 pytest executorch/runtime/test_*.py -s -vvvv --durations=0 diff --git a/docs/source/exporters/executorch/overview.mdx b/docs/source/exporters/executorch/overview.mdx new file mode 100644 index 00000000000..0e880968bf7 --- /dev/null +++ b/docs/source/exporters/executorch/overview.mdx @@ -0,0 +1,26 @@ + + +# Overview + +πŸ€— Optimum handles the export of PyTorch to ExecuTorch in the `exporters.executorch` module. It provides classes, functions, and a command line interface to perform the export easily. + +Supported architectures from [πŸ€— Transformers](https://huggingface.co/docs/transformers/index): + +- Gemma +- Gemma2 +- Llama2 +- Llama3(Llama3.2) +- OLMo +- Qwen2(Qwen2.5) + +There are many more models are supported by ExecuTorch, we will add those models to Optimum over time. Read more at [pytorch/executorch/examples/](https://github.com/pytorch/executorch/tree/main/examples) diff --git a/docs/source/exporters/executorch/package_reference/configuration.mdx b/docs/source/exporters/executorch/package_reference/configuration.mdx new file mode 100644 index 00000000000..b7a10b80419 --- /dev/null +++ b/docs/source/exporters/executorch/package_reference/configuration.mdx @@ -0,0 +1,54 @@ + + +# Configuration for ExecuTorch Export + +ExecuTorch export provides a flexible configuration mechanism through dynamic registration, enabling users to have +complete control over the export process. The configuration system is divided into task configurations and recipe +configurations, each addressing specific aspects of the export pipeline. + + +## Task Configurations + +Task configurations determine how a Hugging Face model should be loaded and prepared for export, tailored to specific tasks. + +For instance, when exporting a model for a text generation task, the provided configuration utilizes **static caching** and +**SDPA (Scaled Dot-Product Attention)** for inference optimization. + +By leveraging task configurations, users can ensure that their models are appropriately prepared for efficient execution on +the ExecuTorch backend. + +[[autodoc]] exporters.executorch.task_registry.discover_tasks + +[[autodoc]] exporters.executorch.task_registry.register_task + +[[autodoc]] exporters.executorch.tasks.causal_lm.load_causal_lm_model + + +## Recipe Configurations + +Recipe configurations control the specifics of lowering an eager PyTorch module to the ExecuTorch backend. These +configurations allow users to: + +- Specify whether and how to **quantize** the model. +- Delegate computation to various accelerators, such as **CPU**, **GPU**, **NPU**, **DSP**, and others. +- Define **custom transformation passes**. +- Implement advanced techniques like memory planning algorithms to optimize resource utilization. + +[[autodoc]] exporters.executorch.recipe_registry.discover_recipes + +[[autodoc]] exporters.executorch.recipe_registry.register_recipe + +[[autodoc]] exporters.executorch.recipes.xnnpack.export_to_executorch_with_xnnpack + +The combination of task and recipe configurations ensures that users can customize both the high-level task setup +and the low-level export details to suit their deployment requirements. diff --git a/docs/source/exporters/executorch/package_reference/export.mdx b/docs/source/exporters/executorch/package_reference/export.mdx new file mode 100644 index 00000000000..6663eb5278e --- /dev/null +++ b/docs/source/exporters/executorch/package_reference/export.mdx @@ -0,0 +1,26 @@ + + +# Export functions + +## Main functions + +[[autodoc]] exporters.executorch.convert.export_to_executorch + +The primary export function is designed to be **model- and task-independent** as well as **optimization-agnostic**, providing a +highly flexible and modular interface for exporting Hugging Face models to the ExecuTorch backend. + +This approach highlights the **composability** of ExecuTorch export pipeline, where dynamically registered **task configurations** +specify how a :hug model is prepared, and **recipe configurations** encapsulate device-specific optimizations during export. This +separation allows users to customize the export process without altering the core function. + +For more details on task and recipe configurations, see the [Configuration for ExecuTorch Export](./configuration.mdx). diff --git a/docs/source/exporters/executorch/usage_guides/contribute.mdx b/docs/source/exporters/executorch/usage_guides/contribute.mdx new file mode 100644 index 00000000000..2c6c1593169 --- /dev/null +++ b/docs/source/exporters/executorch/usage_guides/contribute.mdx @@ -0,0 +1,57 @@ + + +# Adding support for an unsupported architecture + +We welcome contributions to extend the functionality of ExecuTorch export. This guide provides high-level instructions for contributors who want to: + +1. Export a new model that is not currently supported. +2. Add new recipes or support a new task for export. + +--- + +## Exporting a New Model + +If you want to export a model that is not already supported by the library, follow these steps: + +### Step 1: Export and Test the Model +1. Attempt to export and lower the model using an existing task and recipe. On success, it will store the exported model in a `.pte` file. +2. Add a test case for the model in the appropriate test suite. + - For example, you can make sure tests pass for the new `my_new_model` by running: + ```bash + pytest tests/executorch/export/test_*.py -k "test_my_new_model" # doctest: +SKIP + pytest tests/executorch/runtime/test_*.py -k "test_my_new_model" # doctest: +SKIP + ``` + +### Step 2: Handle Export Failures +1. If the export fails in Step 1, report the issue by opening a GitHub issue. +2. If the issue requires changes to the model’s architecture or its Hugging Face implementation, these modifications may be made upstream in the Hugging Face Transformers library. + +--- + +## Adding New Recipes or Tasks + +To extend ExecuTorch with new recipes or tasks, follow these guidelines: + +### Registering a New Recipe +You can add a custom recipe to define specific optimizations or configurations for exporting models. Below is an example: + +```python +from exporters.executorch import register_recipe + +@register_recipe("my_custom_recipe") +def export_with_custom_recipe(model, config, *args, **kwargs): + # Example: Apply a custom quantization +``` + +### Registering a Task +The task registration process is same as adding a recipe. Besides that you may need to implement a new `ExecuTorchModelForXXX` class. diff --git a/docs/source/exporters/executorch/usage_guides/export_a_model.mdx b/docs/source/exporters/executorch/usage_guides/export_a_model.mdx new file mode 100644 index 00000000000..7993188cbd5 --- /dev/null +++ b/docs/source/exporters/executorch/usage_guides/export_a_model.mdx @@ -0,0 +1,124 @@ + + +# Export a model to ExecuTorch with optimum.exporters.executorch + +If you need to deploy πŸ€— Transformers models for on-device use cases, we recommend +exporting them to a serialized format that can be distributed and executed on specialized +runtimes and hardware. In this guide, we'll show you how to export these +models to [ExecuTorch](https://pytorch.org/executorch/main/intro-overview.html). + + +## Why ExecuTorch? + +ExecuTorch is the ideal solution for deploying PyTorch models on edge devices, offering a streamlined process from +export to deployment without leaving PyTorch ecosystem. + +Supporting on-device AI presents unique challenges with diverse hardware, critical power requirements, low/no internet +connectivity, and realtime processing needs. These constraints have historically prevented or slowed down the creation +of scalable and performant on-device AI solutions. We designed ExecuTorch, backed by our industry partners like Meta, +Arm, Apple, Qualcomm, MediaTek, etc. to be highly portable and provide superior developer productivity without losing on +performance. + + +## Summary + +Exporting a PyTorch model to ExecuTorch is as simple as + +```bash +optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b" +``` + +Check out the help for more options: + +```bash +optimum-cli export executorch --help +``` + + +## Exporting a model to ExecuTorch using the CLI + +To export a πŸ€— Transformers model to ExecuTorch, you'll first need to install some extra +dependencies: + +```bash +pip install optimum[exporters-executorch] +``` + +The Optimum ExecuTorch export can be used through Optimum command-line: + +```bash +optimum-cli export executorch --help + +usage: optimum-cli export executorch [-h] -m MODEL [-o OUTPUT_DIR] [--task TASK] [--recipe RECIPE] + +options: + -h, --help show this help message and exit + +Required arguments: + -m MODEL, --model MODEL + Model ID on huggingface.co or path on disk to load model from. + -o OUTPUT_DIR, --output_dir OUTPUT_DIR + Path indicating the directory where to store the generated ExecuTorch model. + --task TASK The task to export the model for. Available tasks depend on the model, but are among: ['audio-classification', 'feature-extraction', 'image-to-text', + 'sentence-similarity', 'depth-estimation', 'image-segmentation', 'audio-frame-classification', 'masked-im', 'semantic-segmentation', 'text-classification', + 'audio-xvector', 'mask-generation', 'question-answering', 'text-to-audio', 'automatic-speech-recognition', 'image-to-image', 'multiple-choice', 'image- + classification', 'text2text-generation', 'token-classification', 'object-detection', 'zero-shot-object-detection', 'zero-shot-image-classification', 'text- + generation', 'fill-mask']. + --recipe RECIPE Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack". + +``` + +Exporting a checkpoint can be done as follows: + +```bash +optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b" +``` + +You should see a `model.pte` file is stored under "./meta_llama3_2_1b/": + +```bash +meta_llama3_2_1b/ +└── model.pte +``` + +This will fetch the model on the Hub and exports the PyTorch model with the specialized recipe. The resulting `model.pte` file can then be run on the [XNNPACK backend](https://pytorch.org/executorch/main/tutorial-xnnpack-delegate-lowering.html), or on many +other ExecuTorh supported backends if exports with different recipes, e.g. Apple's [Core ML](https://pytorch.org/executorch/main/build-run-coreml.html) or [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [Qualcomm's SoCs](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html), [ARM's Ethos-U](https://pytorch.org/executorch/main/executorch-arm-delegate-tutorial.html), [Xtensa HiFi4 DSP](https://pytorch.org/executorch/main/build-run-xtensa.html), [Vulkan GPU](https://pytorch.org/executorch/main/build-run-vulkan.html), [MediaTek](https://pytorch.org/executorch/main/build-run-mediatek-backend.html), etc. + +For example, we can load and run the model with [ExecuTorch +Runtime](https://pytorch.org/executorch/main/runtime-overview.html) using the `optimum.executorchruntime` package as follows: + +```python +>>> from transformers import AutoTokenizer +>>> from optimum.executorchruntime import ExecuTorchModelForCausalLM + +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") # doctest: +SKIP +>>> model = ExecuTorchModelForCausalLM.from_pretrained("meta_llama3_2_1b/", export=False) # doctest: +SKIP + +>>> generated_text = model.text_generation(tokenizer=tokenizer, prompt="Simply put, the theory of relativity states that", max_seq_len=45) # doctest: +SKIP +``` + +Printing the `generated_text` would give that: + +``` +"Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference. In other words, the laws of physics are the same in all inertial frames of reference." +``` + +As you can see, converting a model to ExecuTorch does not mean leaving the Hugging Face ecosystem. You end up with a similar API as regular πŸ€— Transformers models! + +It is also possible to export the model to ExecuTorch directly from the `ExecuTorchModelForCausalLM` class by doing the following: + +```python +>>> from optimum.executorchruntime import ExecuTorchModelForCausalLM + +>>> model = ExecuTorchModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", export=True, task="text-generation", recipe="xnnpack") +``` diff --git a/docs/source/exporters/overview.mdx b/docs/source/exporters/overview.mdx index 6fd7bd9d916..2b4c2e11792 100644 --- a/docs/source/exporters/overview.mdx +++ b/docs/source/exporters/overview.mdx @@ -12,4 +12,4 @@ specific language governing permissions and limitations under the License. # Overview -πŸ€— Optimum enables exporting models from PyTorch or TensorFlow to different formats through its `exporters` module. For now, two exporting format are supported: ONNX and TFLite (TensorFlow Lite). +πŸ€— Optimum enables exporting models from PyTorch or TensorFlow to different formats through its `exporters` module. For now, three exporting format are supported: ONNX, TFLite (TensorFlow Lite), and ExecuTorch. diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 8a2a276d1c5..a31344ed133 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -14,5 +14,5 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand -from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand +from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand from .optimum_cli import optimum_cli_subcommand diff --git a/optimum/commands/export/__init__.py b/optimum/commands/export/__init__.py index 19da68a60d2..b72cd5dbc8d 100644 --- a/optimum/commands/export/__init__.py +++ b/optimum/commands/export/__init__.py @@ -14,5 +14,6 @@ from .base import ExportCommand +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand diff --git a/optimum/commands/export/base.py b/optimum/commands/export/base.py index 07737cb8eaf..e5ed4c90ff5 100644 --- a/optimum/commands/export/base.py +++ b/optimum/commands/export/base.py @@ -15,6 +15,7 @@ """optimum.exporters command-line interface base classes.""" from .. import BaseOptimumCLICommand, CommandInfo +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand @@ -25,6 +26,11 @@ class ExportCommand(BaseOptimumCLICommand): help="Export PyTorch and TensorFlow models to several format.", ) SUBCOMMANDS = ( + CommandInfo( + name="executorch", + help="Export PyTorch model to ExecuTorch.", + subcommand_class=ExecuTorchExportCommand, + ), CommandInfo( name="onnx", help="Export PyTorch and TensorFlow to ONNX.", diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py new file mode 100644 index 00000000000..2bf2f1d3054 --- /dev/null +++ b/optimum/commands/export/executorch.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Defines the command line for the export with ExecuTorch.""" + +from pathlib import Path +from typing import TYPE_CHECKING + +from ...exporters import TasksManager +from ..base import BaseOptimumCLICommand + + +if TYPE_CHECKING: + from argparse import ArgumentParser + + +def parse_args_executorch(parser): + required_group = parser.add_argument_group("Required arguments") + required_group.add_argument( + "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from." + ) + required_group.add_argument( + "-o", + "--output_dir", + type=Path, + help="Path indicating the directory where to store the generated ExecuTorch model.", + ) + required_group.add_argument( + "--task", + type=str, + default="text-generation", + help=( + "The task to export the model for. Available tasks depend on the model, but are among:" + f" {str(TasksManager.get_all_tasks())}." + ), + ) + required_group.add_argument( + "--recipe", + type=str, + default="xnnpack", + help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".', + ) + + +class ExecuTorchExportCommand(BaseOptimumCLICommand): + @staticmethod + def parse_args(parser: "ArgumentParser"): + return parse_args_executorch(parser) + + def run(self): + from ...exporters.executorch import main_export + + main_export( + model_name_or_path=self.args.model, + task=self.args.task, + recipe=self.args.recipe, + output_dir=self.args.output_dir, + ) diff --git a/optimum/executorchruntime/__init__.py b/optimum/executorchruntime/__init__.py new file mode 100644 index 00000000000..0a84c3a139b --- /dev/null +++ b/optimum/executorchruntime/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "modeling_executorch": [ + "ExecuTorchModelForCausalLM", + ], +} + +if TYPE_CHECKING: + from .modeling_executorch import ExecuTorchModelForCausalLM +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/optimum/executorchruntime/modeling_executorch.py b/optimum/executorchruntime/modeling_executorch.py new file mode 100644 index 00000000000..39c75a03863 --- /dev/null +++ b/optimum/executorchruntime/modeling_executorch.py @@ -0,0 +1,464 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers.""" + +import logging +import os +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +from executorch.extension.pybindings.portable_lib import ( + ExecuTorchModule, + _load_for_executorch, +) +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from transformers import ( + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedTokenizer, +) + +from ..exporters.executorch import main_export +from ..modeling_base import OptimizedModel + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + +logger = logging.getLogger(__name__) + + +class ExecuTorchModelForCausalLM(OptimizedModel): + """ + ExecuTorch model with a causal language modeling head for inference using the ExecuTorch Runtime. + + This class provides an interface for loading, running, and generating outputs from a causal language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + et_model (`ExecuTorchModule`): + The loaded ExecuTorch model. + use_kv_cache (`bool`): + Whether key-value caching is enabled. For performance reasons, the exported model is + optimized to use a static cache. + max_cache_size (`int`): + Maximum sequence length supported by the cache. + max_batch_size (`int`): + Maximum supported batch size. + dtype (`str`): + Data type of the model parameters. + bos_token_id (`int`): + Beginning-of-sequence token ID. + eos_token_id (`int`): + End-of-sequence token ID. + vocab_size (`int`): + Size of the model vocabulary. + """ + + auto_model_class = AutoModelForCausalLM + + def __init__( + self, + model: "ExecuTorchModule", + config: "PretrainedConfig", + ): + super().__init__(model, config) + self.et_model = model + metadata = self.et_model.method_names() + logging.info(f"Load all static methods: {metadata}") + if "use_kv_cache" in metadata: + self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.et_model.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.et_model.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.et_model.run_method("get_bos_id")[0] + if "get_eos_id" in metadata: + self.eos_token_id = self.et_model.run_method("get_eos_id")[0] + if "get_vocab_size" in metadata: + self.vocab_size = self.et_model.run_method("get_vocab_size")[0] + + def forward( + self, + input_ids: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. + + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the model. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + return self.et_model.forward((input_ids, cache_position))[0] + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + export: bool = True, + task: str = "", + recipe: str = "", + config: "PretrainedConfig" = None, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model. + + Args: + model_name_or_path (`Union[str, Path]`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`. + export (`bool`, *optional*, defaults to `True`): + If `True`, the model will be exported from eager to ExecuTorch after fetched from huggingface.co. `model_name_or_path` must be a valid model ID on huggingface.co. + If `False`, the previously exported ExecuTorch model will be loaded from a local path. `model_name_or_path` must be a valid local directory where a `model.pte` is stored. + task (`str`, defaults to `""`): + The task to export the model for, e.g. "text-generation". It is required to specify a task when `export` is `True`. + recipe (`str`, defaults to `""`): + The recipe to use to do the export, e.g. "xnnpack". It is required to specify a task when `export` is `True`. + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: An instance of the ExecuTorch model for text generation task. + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + if export: + # Fetch the model from huggingface.co and export it to ExecuTorch + if task == "": + raise ValueError("Please specify a task to export the model for.") + if recipe == "": + raise ValueError("Please specify a recipe to export the model for.") + return cls._export( + model_id=model_name_or_path, + task=task, + recipe=recipe, + config=config, + **kwargs, + ) + else: + # Load the ExecuTorch model from a local path + return cls._from_pretrained( + model_dir_path=model_name_or_path, + config=config, + ) + + @classmethod + def _from_pretrained( + cls, + model_dir_path: Union[str, Path], + config: PretrainedConfig, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model from a local directory. + + Args: + model_dir_path (`Union[str, Path]`): + Path to the directory containing the ExecuTorch model file (`model.pte`). + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + + Returns: + `ExecuTorchModelForCausalLM`: The initialized ExecuTorch model. + + """ + full_path = os.path.join(f"{model_dir_path}", "model.pte") + model = _load_for_executorch(full_path) + logging.info(f"Loaded model from {full_path}") + logging.debug(f"{model.method_meta('forward')}") + return cls( + model=model, + config=config, + ) + + def _save_pretrained(self, save_directory): + """ + Saves a model weights into a directory, so that it can be re-loaded using the + [`from_pretrained`] class method. + """ + raise NotImplementedError + + @classmethod + def _export( + cls, + model_id: str, + task: str, + recipe: str, + config: PretrainedConfig, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + subfolder: str = "", + revision: Optional[str] = None, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, + ): + """ + Fetch a model from the Hugging Face Hub and export it to ExecuTorch format. + + Args: + model_id (`str`): + Model ID on huggingface.co, for example: `model_name_or_path="meta-llama/Llama-3.2-1B"`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: The loaded and exported ExecuTorch model. + + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + # Export to ExecuTorch and save the pte file to the temporary directory + main_export( + model_name_or_path=model_id, + output_dir=save_dir_path, + task=task, + recipe=recipe, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cls._from_pretrained( + model_dir_path=save_dir_path, + config=config, + use_auth_token=use_auth_token, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + ) + + def generate( + self, + prompt_tokens: List[int], + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + ) -> List[int]: + """ + Generate tokens from a prompt using the ExecuTorch model. + + Args: + prompt_tokens (List[int]): + List of token IDs representing the prompt. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `False`. + pos_base (`int`, *optional*): + Base position for the prompt tokens. Defaults to 0. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + + Returns: + List[int]: List of generated token IDs. + + Note: + Temporarily implemented this method in Python due to limited access to ExecuTorch's c++ LLM runner via pybind. + Expect improvements to the pybind interface in ExecuTorch version 0.4.1. + """ + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logging.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_seq_len = self.max_cache_size + generated_tokens = [] + + # prefill + for i, prompt_token in enumerate(prompt_tokens): + logits = self.forward( + input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor([i], dtype=torch.long, device=self.device), + ) + + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens = prompt_tokens + [next_token] + + while len(generated_tokens) < max_seq_len: + logits = self.forward( + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor( + [pos_base + len(generated_tokens) - 1], + dtype=torch.long, + device=self.device, + ), + ) + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens.append(next_token) + if next_token == self.eos_token_id: + break + + return generated_tokens if echo else generated_tokens[len(prompt_tokens) :] + + def text_generation( + self, + tokenizer: "PreTrainedTokenizer", + prompt: str, + echo: bool = True, + max_seq_len: Optional[int] = None, + ): + """ + Perform text generation task for a given prompt using the ExecuTorch model. + + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used to encode and decode the prompt and output. + prompt (`str`): + The text prompt to complete. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `True`. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + """ + self.tokenizer = tokenizer + + # Sanity check + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + raise ValueError( + f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." + ) + if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id: + raise ValueError( + f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}." + ) + + prompt_tokens = self.tokenizer.encode(prompt) + generated_tokens = self.generate( + prompt_tokens=prompt_tokens, + echo=echo, + max_seq_len=max_seq_len, + ) + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py new file mode 100644 index 00000000000..cbdd2bfc0a9 --- /dev/null +++ b/optimum/exporters/executorch/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "convert": [ + "export_to_executorch", + ], + "recipe_registry": [ + "discover_recipes", + "register_recipe", + ], + "task_registry": [ + "discover_tasks", + "register_task", + ], + "__main__": ["main_export"], +} + +if TYPE_CHECKING: + from .__main__ import main_export + from .convert import export_to_executorch +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py new file mode 100644 index 00000000000..33a668b0674 --- /dev/null +++ b/optimum/exporters/executorch/__main__.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Entry point to the optimum.exporters.executorch command line.""" + +import argparse +import os +import warnings +from pathlib import Path + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from transformers.utils import is_torch_available + +from optimum.utils.import_utils import check_if_transformers_greater + +from ...commands.export.executorch import parse_args_executorch +from .convert import export_to_executorch +from .task_registry import discover_tasks, task_registry + + +if is_torch_available(): + pass + +from typing import Optional, Union + + +def main_export( + model_name_or_path: str, + task: str, + recipe: str, + output_dir: Union[str, Path], + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + pad_token_id: Optional[int] = None, + subfolder: str = "", + revision: str = "main", + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, +): + """ + Full-suite ExecuTorch export function, exporting **from a model ID on Hugging Face Hub or a local model repository**. + + Args: + model_name_or_path (`str`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + output_dir (`Union[str, Path]`): + Path indicating the directory where to store the generated ExecuTorch model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + pad_token_id (`Optional[int]`, defaults to `None`): + This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Example usage: + ```python + >>> from optimum.exporters.executorch import main_export + + >>> main_export("meta-llama/Llama-3.2-1B", "text-generation", "xnnpack", "meta_llama3_2_1b/") + ``` + """ + + if not check_if_transformers_greater("4.46"): + raise ValueError( + "The minimum Transformers version compatible with ExecuTorch is 4.46.0. Please upgrade to Transformers 4.46.0 or later." + ) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + # Dynamically discover and import registered tasks + discover_tasks() + + # Load the model for specific task + try: + task_func = task_registry.get(task) + except KeyError as e: + raise RuntimeError(f"The task '{task}' isn't registered. Detailed error: {e}") + + model = task_func(model_name_or_path, **kwargs) + + if task == "text-generation": + from transformers.integrations.executorch import TorchExportableModuleWithStaticCache + + model = TorchExportableModuleWithStaticCache(model) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + return export_to_executorch( + model=model, + task=task, + recipe=recipe, + output_dir=output_dir, + **kwargs, + ) + + +def main(): + parser = argparse.ArgumentParser("Hugging Face Optimum ExecuTorch exporter") + + parse_args_executorch(parser) + + # Retrieve CLI arguments + args = parser.parse_args() + + main_export( + model_name_or_path=args.model, + output_dir=args.output_dir, + task=args.task, + recipe=args.recipe, + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + pad_token_id=args.pad_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py new file mode 100644 index 00000000000..f50a4b54a96 --- /dev/null +++ b/optimum/exporters/executorch/convert.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""ExecuTorch model check and export functions.""" + +import logging +import os +from pathlib import Path +from typing import Union + +from transformers.utils import is_torch_available + +from optimum.utils.import_utils import check_if_transformers_greater + +from .recipe_registry import discover_recipes, recipe_registry + + +if is_torch_available(): + from transformers.modeling_utils import PreTrainedModel + +if check_if_transformers_greater("4.46"): + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + +logger = logging.getLogger(__name__) + + +def export_to_executorch( + model: Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"], + task: str, + recipe: str, + output_dir: Union[str, Path], + **kwargs, +): + """ + Export a pre-trained PyTorch model to the ExecuTorch format using a specified recipe. + + This function facilitates the transformation of a PyTorch model into an optimized ExecuTorch program. + + Args: + model (`Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"]`): + A PyTorch model to be exported. This can be a standard HuggingFace `PreTrainedModel` or a wrapped + module like `TorchExportableModuleWithStaticCache` for text generation task. + task (`str`): + The specific task the exported model will perform, e.g., "text-generation". + recipe (`str`): + The recipe to guide the export process, e.g., "xnnpack". Recipes define the optimization and lowering steps. + Will raise an exception if the specified recipe is not registered in the recipe registry. + output_dir (`Union[str, Path]`): + Path to the directory where the resulting ExecuTorch model will be saved. + **kwargs: + Additional configuration options passed to the recipe. + + Returns: + `ExecuTorchProgram`: + The lowered ExecuTorch program object. + + Notes: + - The function uses a dynamic recipe discovery mechanism to identify and import the specified recipe. + - The exported model is stored in the specified output directory with the fixed filename `model.pte`. + - The resulting ExecuTorch program is serialized and saved to the output directory. + """ + + # Dynamically discover and import registered recipes + discover_recipes() + + # Export and lower the model to ExecuTorch with the recipe + try: + recipe_func = recipe_registry.get(recipe) + except KeyError as e: + raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}") + + executorch_prog = recipe_func(model, task, **kwargs) + + full_path = os.path.join(f"{output_dir}", "model.pte") + with open(full_path, "wb") as f: + executorch_prog.write_to_file(f) + logging.info(f"Saved exported program to {full_path}") + + return executorch_prog diff --git a/optimum/exporters/executorch/recipe_registry.py b/optimum/exporters/executorch/recipe_registry.py new file mode 100644 index 00000000000..2eb728b7573 --- /dev/null +++ b/optimum/exporters/executorch/recipe_registry.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import importlib +import logging +import pkgutil + + +logger = logging.getLogger(__name__) + +recipe_registry = {} + +package_name = "optimum.exporters.executorch.recipes" + + +def register_recipe(recipe_name): + """ + Decorator to register a recipe for exporting and lowering an ExecuTorch model under a specific name. + + Args: + recipe_name (`str`): + The name of the recipe to associate with a callable recipe. + + Returns: + `Callable`: + The original function wrapped as a registered recipe. + + Example: + ```python + @register_recipe("my_new_recipe") + def my_new_recipe(...): + ... + ``` + """ + + def decorator(func): + recipe_registry[recipe_name] = func + return func + + return decorator + + +def discover_recipes(): + """ + Dynamically discovers and imports all recipe modules within the `optimum.exporters.executorch.recipes` package. + + Ensures recipes under `./recipes` directory are dynamically loaded without requiring manual imports. + + Notes: + New recipes **must** be added to the `./recipes` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Recipes must also use the + `@register_recipe` decorator to be properly registered in the `recipe_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/recipes/__init__.py b/optimum/exporters/executorch/recipes/__init__.py new file mode 100644 index 00000000000..30466c2d1a1 --- /dev/null +++ b/optimum/exporters/executorch/recipes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py new file mode 100644 index 00000000000..d3b3a5d52aa --- /dev/null +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import Union + +import torch +import torch.export._trace +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from torch.nn.attention import SDPBackend +from transformers import PreTrainedModel, TorchExportableModuleWithStaticCache + +from ..recipe_registry import register_recipe + + +@register_recipe("xnnpack") +def export_to_executorch_with_xnnpack( + model: Union[PreTrainedModel, TorchExportableModuleWithStaticCache], + task: str, + **kwargs, +): + """ + Export a PyTorch model to ExecuTorch w/ delegation to XNNPACK backend. + + This function also write metadata required by the ExecuTorch runtime to the model. + + Args: + model (Union[PreTrainedModel, TorchExportableModuleWithStaticCache]): + The PyTorch model to be exported to ExecuTorch. + task (str): + The task name to export the model for (e.g., "text-generation"). + **kwargs: + Additional keyword arguments for recipe-specific configurations. + + Returns: + ExecuTorchProgram: + The exported and optimized program for ExecuTorch. + """ + metadata = {} + if task == "text-generation": + example_input_ids = torch.tensor([[1]], dtype=torch.long) + example_cache_position = torch.tensor([0], dtype=torch.long) + + def _get_constant_methods(model: PreTrainedModel): + metadata = { + "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6, + "get_bos_id": model.config.bos_token_id, + "get_eos_id": model.config.eos_token_id, + "get_head_dim": model.config.hidden_size / model.config.num_attention_heads, + "get_max_batch_size": model.generation_config.cache_config.batch_size, + "get_max_seq_len": model.generation_config.cache_config.max_cache_len, + "get_n_kv_heads": model.config.num_key_value_heads, + "get_n_layers": model.config.num_hidden_layers, + "get_vocab_size": model.config.vocab_size, + "use_kv_cache": model.generation_config.use_cache, + } + return {k: v for k, v in metadata.items() if v is not None} + + metadata = _get_constant_methods(model if isinstance(model, PreTrainedModel) else model.model) + else: + # TODO: Prepare model inputs for other tasks + raise ValueError(f"Unsupported task '{task}'.") + + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + exported_program = torch.export._trace._export( + model, + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) + + return to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _skip_dim_order=True, + ), + constant_methods=metadata, + ).to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + ), + ) diff --git a/optimum/exporters/executorch/task_registry.py b/optimum/exporters/executorch/task_registry.py new file mode 100644 index 00000000000..fdc34f0359a --- /dev/null +++ b/optimum/exporters/executorch/task_registry.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import importlib +import logging +import pkgutil + + +logger = logging.getLogger(__name__) + +task_registry = {} + +package_name = "optimum.exporters.executorch.tasks" + + +def register_task(task_name): + """ + Decorator to register a task under a specific name. + + Args: + task_name (`str`): + The name of the task to associate with a callable task. + + Returns: + `Callable`: + The original function wrapped as a registered task. + + Example: + ```python + @register_task("my_new_task") + def my_new_task(...): + ... + ``` + """ + + def decorator(func): + task_registry[task_name] = func + return func + + return decorator + + +def discover_tasks(): + """ + Dynamically discovers and imports all task modules within the `optimum.exporters.executorch.tasks` package. + + Ensures tasks under `./tasks` directory are dynamically loaded without requiring manual imports. + + Notes: + New tasks **must** be added to the `./tasks` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Tasks must also use the + `@register_task` decorator to be properly registered in the `task_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py new file mode 100644 index 00000000000..30466c2d1a1 --- /dev/null +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py new file mode 100644 index 00000000000..b02da8b319e --- /dev/null +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from transformers import AutoModelForCausalLM, GenerationConfig + +from ..task_registry import register_task + + +@register_task("text-generation") +def load_causal_lm_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for text generation and registers it under the task + 'text-generation' using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). + + Returns: + transformers.PreTrainedModel: + An instance of a model subclass (e.g., Llama, Gemma) with the configuration for exporting + and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + attn_implementation = kwargs.get("attn_implementation", "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + max_length = kwargs.get("max_length", 2048) + + return AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index 1d982949344..d21db2a4aca 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body[ - "model_type" - ] = self.torch_model.config.model_type # return_body is initialized in parent class + self.return_body["model_type"] = ( + self.torch_model.config.model_type + ) # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) diff --git a/setup.py b/setup.py index 6736085943a..bb5bcc11d43 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,10 @@ "datasets<=2.16", "transformers>=4.36,<4.38", ], + "exporters-executorch": [ + "executorch>=0.4.0", + "transformers>=4.46", + ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", "openvino": "optimum-intel[openvino]>=1.18.0", diff --git a/tests/executorch/export/__init__.py b/tests/executorch/export/__init__.py new file mode 100644 index 00000000000..fdc02578672 --- /dev/null +++ b/tests/executorch/export/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/executorch/export/test_exporters_executorch.py b/tests/executorch/export/test_exporters_executorch.py new file mode 100644 index 00000000000..a4521bc0183 --- /dev/null +++ b/tests/executorch/export/test_exporters_executorch.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import tempfile +import unittest + +import pytest +from transformers.testing_utils import slow + + +class TestExportToExecuTorchCLI(unittest.TestCase): + def test_helps_no_raise(self): + subprocess.run( + "optimum-cli export executorch --help", + shell=True, + check=True, + ) + + @slow + @pytest.mark.run_slow + def test_llama3_2_1b_export_to_executorch(self): + model_id = "meta-llama/Llama-3.2-1B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_llama3_2_3b_export_to_executorch(self): + model_id = "meta-llama/Llama-3.2-3B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_export_to_executorch(self): + model_id = "Qwen/Qwen2.5-0.5B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_gemma2_export_to_executorch(self): + model_id = "google/gemma-2-2b" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_gemma_export_to_executorch(self): + model_id = "google/gemma-2b" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_olmo_export_to_executorch(self): + model_id = "allenai/OLMo-1B-hf" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) diff --git a/tests/executorch/runtime/__init__.py b/tests/executorch/runtime/__init__.py new file mode 100644 index 00000000000..fdc02578672 --- /dev/null +++ b/tests/executorch/runtime/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py new file mode 100644 index 00000000000..88caf81b6d5 --- /dev/null +++ b/tests/executorch/runtime/test_modeling.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import slow + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_load_model_from_hub(self): + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path="meta-llama/Llama-3.2-1B", + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + @slow + @pytest.mark.run_slow + def test_load_model_from_local_path(self): + from optimum.exporters.executorch import main_export + + model_id = "meta-llama/Llama-3.2-1B" + task = "text-generation" + recipe = "xnnpack" + + with tempfile.TemporaryDirectory() as tempdir: + # Export to a local dir + main_export( + model_name_or_path=model_id, + task=task, + recipe=recipe, + output_dir=tempdir, + ) + self.assertTrue(os.path.exists(f"{tempdir}/model.pte")) + + # Load the exported model from a local dir + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=tempdir, + export=False, + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + @slow + @pytest.mark.run_slow + def test_llama3_2_1b_text_generation_with_xnnpack(self): + model_id = "meta-llama/Llama-3.2-1B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_llama3_2_3b_text_generation_with_xnnpack(self): + model_id = "meta-llama/Llama-3.2-3B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Simply put, the theory of relativity states that the speed of light is constant. This " + "means that no matter how fast you are traveling, the speed of light will always be " + "186,000 miles per second." + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_text_generation_with_xnnpack(self): + model_id = "Qwen/Qwen2.5-0.5B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "My favourite condiment is iced tea. I love it with my breakfast, my lunch" + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="My favourite condiment is ", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_gemma2_text_generation_with_xnnpack(self): + model_id = "google/gemma-2-2b" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school. I need help with my science homework" + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_gemma_text_generation_with_xnnpack(self): + model_id = "google/gemma-2b" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to make a 3D model of a car." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_olmo_text_generation_with_xnnpack(self): + model_id = "allenai/OLMo-1B-hf" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Simply put, the theory of relativity states that the speed of light is the same in all directions." + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) From b947760fb8989d32e8a63cce0d68536de3dc0b6e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 Dec 2024 15:07:08 +0100 Subject: [PATCH 02/30] Test I can push From 301bbcf9854825846c3b413970195ea1c8d29305 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 Dec 2024 15:16:56 +0100 Subject: [PATCH 03/30] Styling --- optimum/executorchruntime/modeling_executorch.py | 6 +----- optimum/onnxruntime/runs/__init__.py | 6 +++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/optimum/executorchruntime/modeling_executorch.py b/optimum/executorchruntime/modeling_executorch.py index 39c75a03863..b93309f6a48 100644 --- a/optimum/executorchruntime/modeling_executorch.py +++ b/optimum/executorchruntime/modeling_executorch.py @@ -17,7 +17,7 @@ import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import torch from executorch.extension.pybindings.portable_lib import ( @@ -35,10 +35,6 @@ from ..modeling_base import OptimizedModel -if TYPE_CHECKING: - from transformers import PretrainedConfig - - logger = logging.getLogger(__name__) diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index d21db2a4aca..1d982949344 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body["model_type"] = ( - self.torch_model.config.model_type - ) # return_body is initialized in parent class + self.return_body[ + "model_type" + ] = self.torch_model.config.model_type # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) From dec05878afad039e7d8585daf337f3d3097cdb8b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 Dec 2024 15:33:46 +0100 Subject: [PATCH 04/30] Use ungated models for the tests --- tests/executorch/export/test_exporters_executorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/executorch/export/test_exporters_executorch.py b/tests/executorch/export/test_exporters_executorch.py index a4521bc0183..f2467105e4f 100644 --- a/tests/executorch/export/test_exporters_executorch.py +++ b/tests/executorch/export/test_exporters_executorch.py @@ -33,7 +33,7 @@ def test_helps_no_raise(self): @slow @pytest.mark.run_slow def test_llama3_2_1b_export_to_executorch(self): - model_id = "meta-llama/Llama-3.2-1B" + model_id = "NousResearch/Llama-3.2-1B" task = "text-generation" recipe = "xnnpack" with tempfile.TemporaryDirectory() as tempdir: @@ -47,7 +47,7 @@ def test_llama3_2_1b_export_to_executorch(self): @slow @pytest.mark.run_slow def test_llama3_2_3b_export_to_executorch(self): - model_id = "meta-llama/Llama-3.2-3B" + model_id = "NousResearch/Hermes-3-Llama-3.2-3B" task = "text-generation" recipe = "xnnpack" with tempfile.TemporaryDirectory() as tempdir: @@ -75,7 +75,7 @@ def test_qwen2_5_export_to_executorch(self): @slow @pytest.mark.run_slow def test_gemma2_export_to_executorch(self): - model_id = "google/gemma-2-2b" + model_id = "unsloth/gemma-2-2b-it" task = "text-generation" recipe = "xnnpack" with tempfile.TemporaryDirectory() as tempdir: @@ -89,7 +89,7 @@ def test_gemma2_export_to_executorch(self): @slow @pytest.mark.run_slow def test_gemma_export_to_executorch(self): - model_id = "google/gemma-2b" + model_id = "weqweasdas/RM-Gemma-2B" task = "text-generation" recipe = "xnnpack" with tempfile.TemporaryDirectory() as tempdir: From 32d7d7901de4699a681496b694d4f0fd7e59c107 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 Dec 2024 16:44:57 +0100 Subject: [PATCH 05/30] Test if it OOMs --- tests/executorch/export/test_exporters_executorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/executorch/export/test_exporters_executorch.py b/tests/executorch/export/test_exporters_executorch.py index f2467105e4f..ddce086b80b 100644 --- a/tests/executorch/export/test_exporters_executorch.py +++ b/tests/executorch/export/test_exporters_executorch.py @@ -74,6 +74,7 @@ def test_qwen2_5_export_to_executorch(self): @slow @pytest.mark.run_slow + @pytest.mark.skip def test_gemma2_export_to_executorch(self): model_id = "unsloth/gemma-2-2b-it" task = "text-generation" @@ -88,6 +89,7 @@ def test_gemma2_export_to_executorch(self): @slow @pytest.mark.run_slow + @pytest.mark.skip def test_gemma_export_to_executorch(self): model_id = "weqweasdas/RM-Gemma-2B" task = "text-generation" From eaf92c6811b0744a6ae0eeef037abaae34db483d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 Dec 2024 17:12:41 +0100 Subject: [PATCH 06/30] Fix doc --- .../exporters/executorch/package_reference/configuration.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/exporters/executorch/package_reference/configuration.mdx b/docs/source/exporters/executorch/package_reference/configuration.mdx index b7a10b80419..ac1f2278184 100644 --- a/docs/source/exporters/executorch/package_reference/configuration.mdx +++ b/docs/source/exporters/executorch/package_reference/configuration.mdx @@ -48,7 +48,7 @@ configurations allow users to: [[autodoc]] exporters.executorch.recipe_registry.register_recipe -[[autodoc]] exporters.executorch.recipes.xnnpack.export_to_executorch_with_xnnpack +[[autodoc]] exporters.executorch.recipe.xnnpack.export_to_executorch_with_xnnpack The combination of task and recipe configurations ensures that users can customize both the high-level task setup and the low-level export details to suit their deployment requirements. From 7aa1562bf5614a1d71736f61ab42f4eb121dd815 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 Dec 2024 17:20:12 +0100 Subject: [PATCH 07/30] Test with different instance --- .github/workflows/test_executorch_export.yml | 2 +- tests/executorch/export/test_exporters_executorch.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test_executorch_export.yml b/.github/workflows/test_executorch_export.yml index eb8f995f71c..771da52ca5a 100644 --- a/.github/workflows/test_executorch_export.yml +++ b/.github/workflows/test_executorch_export.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: python-version: ['3.10', '3.11', '3.12'] - os: [ubuntu-20.04, macos-15] + os: [ubuntu-20.04-16-core, macos-15] runs-on: ${{ matrix.os }} steps: diff --git a/tests/executorch/export/test_exporters_executorch.py b/tests/executorch/export/test_exporters_executorch.py index ddce086b80b..f2467105e4f 100644 --- a/tests/executorch/export/test_exporters_executorch.py +++ b/tests/executorch/export/test_exporters_executorch.py @@ -74,7 +74,6 @@ def test_qwen2_5_export_to_executorch(self): @slow @pytest.mark.run_slow - @pytest.mark.skip def test_gemma2_export_to_executorch(self): model_id = "unsloth/gemma-2-2b-it" task = "text-generation" @@ -89,7 +88,6 @@ def test_gemma2_export_to_executorch(self): @slow @pytest.mark.run_slow - @pytest.mark.skip def test_gemma_export_to_executorch(self): model_id = "weqweasdas/RM-Gemma-2B" task = "text-generation" From 6d361cebc9d449ac153df078d5c4d8dfa4b2110d Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:01:37 -0800 Subject: [PATCH 08/30] Update configuration.mdx Fix auddoc in configuration.mdx --- .../exporters/executorch/package_reference/configuration.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/exporters/executorch/package_reference/configuration.mdx b/docs/source/exporters/executorch/package_reference/configuration.mdx index ac1f2278184..b7a10b80419 100644 --- a/docs/source/exporters/executorch/package_reference/configuration.mdx +++ b/docs/source/exporters/executorch/package_reference/configuration.mdx @@ -48,7 +48,7 @@ configurations allow users to: [[autodoc]] exporters.executorch.recipe_registry.register_recipe -[[autodoc]] exporters.executorch.recipe.xnnpack.export_to_executorch_with_xnnpack +[[autodoc]] exporters.executorch.recipes.xnnpack.export_to_executorch_with_xnnpack The combination of task and recipe configurations ensures that users can customize both the high-level task setup and the low-level export details to suit their deployment requirements. From 0d9322ad3112f04258f252e2b0d67dbd893a6c22 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Thu, 12 Dec 2024 16:14:47 -0800 Subject: [PATCH 09/30] Experiment to use 'require_read_token' for accessing gated models in test_modeling.py --- tests/executorch/runtime/test_modeling.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py index 88caf81b6d5..ff4c96f9e89 100644 --- a/tests/executorch/runtime/test_modeling.py +++ b/tests/executorch/runtime/test_modeling.py @@ -20,7 +20,10 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer -from transformers.testing_utils import slow +from transformers.testing_utils import ( + require_read_token, + slow, +) from optimum.executorchruntime import ExecuTorchModelForCausalLM @@ -33,7 +36,7 @@ def __init__(self, *args, **kwargs): @pytest.mark.run_slow def test_load_model_from_hub(self): model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path="meta-llama/Llama-3.2-1B", + model_name_or_path="NousResearch/Llama-3.2-1B", export=True, task="text-generation", recipe="xnnpack", @@ -46,7 +49,7 @@ def test_load_model_from_hub(self): def test_load_model_from_local_path(self): from optimum.exporters.executorch import main_export - model_id = "meta-llama/Llama-3.2-1B" + model_id = "NousResearch/Llama-3.2-1B" task = "text-generation" recipe = "xnnpack" @@ -70,6 +73,7 @@ def test_load_model_from_local_path(self): @slow @pytest.mark.run_slow + @require_read_token def test_llama3_2_1b_text_generation_with_xnnpack(self): model_id = "meta-llama/Llama-3.2-1B" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -92,6 +96,7 @@ def test_llama3_2_1b_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow + @require_read_token def test_llama3_2_3b_text_generation_with_xnnpack(self): model_id = "meta-llama/Llama-3.2-3B" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -118,6 +123,7 @@ def test_llama3_2_3b_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow + @require_read_token def test_qwen2_5_text_generation_with_xnnpack(self): model_id = "Qwen/Qwen2.5-0.5B" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -140,6 +146,7 @@ def test_qwen2_5_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow + @require_read_token def test_gemma2_text_generation_with_xnnpack(self): model_id = "google/gemma-2-2b" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -162,6 +169,7 @@ def test_gemma2_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow + @require_read_token def test_gemma_text_generation_with_xnnpack(self): model_id = "google/gemma-2b" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -184,6 +192,7 @@ def test_gemma_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow + @require_read_token def test_olmo_text_generation_with_xnnpack(self): model_id = "allenai/OLMo-1B-hf" model = ExecuTorchModelForCausalLM.from_pretrained( From 3c5f7579929d4fd5b803c40bf9aec74879ae47e2 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 17 Dec 2024 13:10:08 +0100 Subject: [PATCH 10/30] Test with public models and disable linux CI --- .github/workflows/test_executorch_export.yml | 2 +- .github/workflows/test_executorch_runtime.yml | 2 +- tests/executorch/runtime/test_modeling.py | 16 ++++++---------- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test_executorch_export.yml b/.github/workflows/test_executorch_export.yml index 771da52ca5a..1571cd0cffb 100644 --- a/.github/workflows/test_executorch_export.yml +++ b/.github/workflows/test_executorch_export.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: python-version: ['3.10', '3.11', '3.12'] - os: [ubuntu-20.04-16-core, macos-15] + os: [macos-15] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_runtime.yml index f7e3abcceff..3aea14f4ee8 100644 --- a/.github/workflows/test_executorch_runtime.yml +++ b/.github/workflows/test_executorch_runtime.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: python-version: ['3.10', '3.11', '3.12'] - os: [ubuntu-20.04, macos-15] + os: [macos-15] runs-on: ${{ matrix.os }} steps: diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py index ff4c96f9e89..a73c82a4965 100644 --- a/tests/executorch/runtime/test_modeling.py +++ b/tests/executorch/runtime/test_modeling.py @@ -73,9 +73,8 @@ def test_load_model_from_local_path(self): @slow @pytest.mark.run_slow - @require_read_token def test_llama3_2_1b_text_generation_with_xnnpack(self): - model_id = "meta-llama/Llama-3.2-1B" + model_id = "NousResearch/Llama-3.2-1B" model = ExecuTorchModelForCausalLM.from_pretrained( model_name_or_path=model_id, export=True, @@ -96,9 +95,8 @@ def test_llama3_2_1b_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow - @require_read_token def test_llama3_2_3b_text_generation_with_xnnpack(self): - model_id = "meta-llama/Llama-3.2-3B" + model_id = "NousResearch/Hermes-3-Llama-3.2-3B" model = ExecuTorchModelForCausalLM.from_pretrained( model_name_or_path=model_id, export=True, @@ -123,7 +121,6 @@ def test_llama3_2_3b_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow - @require_read_token def test_qwen2_5_text_generation_with_xnnpack(self): model_id = "Qwen/Qwen2.5-0.5B" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -146,9 +143,9 @@ def test_qwen2_5_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow - @require_read_token def test_gemma2_text_generation_with_xnnpack(self): - model_id = "google/gemma-2-2b" + # model_id = "google/gemma-2-2b" + model_id = "unsloth/gemma-2-2b-it" model = ExecuTorchModelForCausalLM.from_pretrained( model_name_or_path=model_id, export=True, @@ -169,9 +166,9 @@ def test_gemma2_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow - @require_read_token def test_gemma_text_generation_with_xnnpack(self): - model_id = "google/gemma-2b" + # model_id = "google/gemma-2b" + model_id = "weqweasdas/RM-Gemma-2B" model = ExecuTorchModelForCausalLM.from_pretrained( model_name_or_path=model_id, export=True, @@ -192,7 +189,6 @@ def test_gemma_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow - @require_read_token def test_olmo_text_generation_with_xnnpack(self): model_id = "allenai/OLMo-1B-hf" model = ExecuTorchModelForCausalLM.from_pretrained( From c36e2e2d060b3ebb79b7a987a08e454e90deda72 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 17 Dec 2024 19:48:40 +0100 Subject: [PATCH 11/30] Trying to fix the doc --- optimum/exporters/executorch/__init__.py | 3 +++ tests/executorch/runtime/test_modeling.py | 13 +++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py index cbdd2bfc0a9..67590723703 100644 --- a/optimum/exporters/executorch/__init__.py +++ b/optimum/exporters/executorch/__init__.py @@ -27,6 +27,9 @@ "discover_tasks", "register_task", ], + "xnnpack": [ + "export_to_executorch_with_xnnpack", + ], "__main__": ["main_export"], } diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py index a73c82a4965..d8c6e1bb498 100644 --- a/tests/executorch/runtime/test_modeling.py +++ b/tests/executorch/runtime/test_modeling.py @@ -21,7 +21,6 @@ from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer from transformers.testing_utils import ( - require_read_token, slow, ) @@ -107,9 +106,9 @@ def test_llama3_2_3b_text_generation_with_xnnpack(self): self.assertIsInstance(model.model, ExecuTorchModule) EXPECTED_GENERATED_TEXT = ( - "Simply put, the theory of relativity states that the speed of light is constant. This " - "means that no matter how fast you are traveling, the speed of light will always be " - "186,000 miles per second." + "Simply put, the theory of relativity states that time is relative and can be affected " + "by an object's speed. This theory was developed by Albert Einstein in the early 20th " + "century. The theory has two parts" ) tokenizer = AutoTokenizer.from_pretrained(model_id) generated_text = model.text_generation( @@ -155,7 +154,9 @@ def test_gemma2_text_generation_with_xnnpack(self): self.assertIsInstance(model, ExecuTorchModelForCausalLM) self.assertIsInstance(model.model, ExecuTorchModule) - EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school. I need help with my science homework" + EXPECTED_GENERATED_TEXT = ( + "Hello I am doing a project for my school and I need to make sure it is a great to be creative and I can!" + ) tokenizer = AutoTokenizer.from_pretrained(model_id) generated_text = model.text_generation( tokenizer=tokenizer, @@ -178,7 +179,7 @@ def test_gemma_text_generation_with_xnnpack(self): self.assertIsInstance(model, ExecuTorchModelForCausalLM) self.assertIsInstance(model.model, ExecuTorchModule) - EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to make a 3D model of a car." + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to write a report on the history of the United States." tokenizer = AutoTokenizer.from_pretrained(model_id) generated_text = model.text_generation( tokenizer=tokenizer, From f0c76d496cb460e8c89af95c2f6eebef584011b4 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 17 Dec 2024 20:27:33 -0800 Subject: [PATCH 12/30] Split modeling tests to separate files and CI jobs --- .github/workflows/test_executorch_runtime.yml | 9 +- tests/executorch/runtime/test_modeling.py | 142 ------------------ .../executorch/runtime/test_modeling_gemma.py | 56 +++++++ .../runtime/test_modeling_gemma2.py | 58 +++++++ .../executorch/runtime/test_modeling_llama.py | 84 +++++++++++ .../executorch/runtime/test_modeling_olmo.py | 56 +++++++ .../executorch/runtime/test_modeling_qwen2.py | 54 +++++++ 7 files changed, 316 insertions(+), 143 deletions(-) create mode 100644 tests/executorch/runtime/test_modeling_gemma.py create mode 100644 tests/executorch/runtime/test_modeling_gemma2.py create mode 100644 tests/executorch/runtime/test_modeling_llama.py create mode 100644 tests/executorch/runtime/test_modeling_olmo.py create mode 100644 tests/executorch/runtime/test_modeling_qwen2.py diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_runtime.yml index 3aea14f4ee8..d5bbc0f8eaa 100644 --- a/.github/workflows/test_executorch_runtime.yml +++ b/.github/workflows/test_executorch_runtime.yml @@ -17,6 +17,13 @@ jobs: matrix: python-version: ['3.10', '3.11', '3.12'] os: [macos-15] + test-modeling: + - test_modeling_gemma2.py + - test_modeling_gemma.py + - test_modeling_llama.py + - test_modeling_olmo.py + - test_modeling.py + - test_modeling_qwen2.py runs-on: ${{ matrix.os }} steps: @@ -32,4 +39,4 @@ jobs: - name: Run tests working-directory: tests run: | - RUN_SLOW=1 pytest executorch/runtime/test_*.py -s -vvvv --durations=0 + RUN_SLOW=1 pytest executorch/runtime/${{ matrix.test-modeling }} -s -vvvv --durations=0 diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py index d8c6e1bb498..6593da7a8c7 100644 --- a/tests/executorch/runtime/test_modeling.py +++ b/tests/executorch/runtime/test_modeling.py @@ -69,145 +69,3 @@ def test_load_model_from_local_path(self): ) self.assertIsInstance(model, ExecuTorchModelForCausalLM) self.assertIsInstance(model.model, ExecuTorchModule) - - @slow - @pytest.mark.run_slow - def test_llama3_2_1b_text_generation_with_xnnpack(self): - model_id = "NousResearch/Llama-3.2-1B" - model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path=model_id, - export=True, - task="text-generation", - recipe="xnnpack", - ) - self.assertIsInstance(model, ExecuTorchModelForCausalLM) - self.assertIsInstance(model.model, ExecuTorchModule) - - EXPECTED_GENERATED_TEXT = "Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference." - tokenizer = AutoTokenizer.from_pretrained(model_id) - generated_text = model.text_generation( - tokenizer=tokenizer, - prompt="Simply put, the theory of relativity states that", - max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), - ) - self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) - - @slow - @pytest.mark.run_slow - def test_llama3_2_3b_text_generation_with_xnnpack(self): - model_id = "NousResearch/Hermes-3-Llama-3.2-3B" - model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path=model_id, - export=True, - task="text-generation", - recipe="xnnpack", - ) - self.assertIsInstance(model, ExecuTorchModelForCausalLM) - self.assertIsInstance(model.model, ExecuTorchModule) - - EXPECTED_GENERATED_TEXT = ( - "Simply put, the theory of relativity states that time is relative and can be affected " - "by an object's speed. This theory was developed by Albert Einstein in the early 20th " - "century. The theory has two parts" - ) - tokenizer = AutoTokenizer.from_pretrained(model_id) - generated_text = model.text_generation( - tokenizer=tokenizer, - prompt="Simply put, the theory of relativity states that", - max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), - ) - self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) - - @slow - @pytest.mark.run_slow - def test_qwen2_5_text_generation_with_xnnpack(self): - model_id = "Qwen/Qwen2.5-0.5B" - model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path=model_id, - export=True, - task="text-generation", - recipe="xnnpack", - ) - self.assertIsInstance(model, ExecuTorchModelForCausalLM) - self.assertIsInstance(model.model, ExecuTorchModule) - - EXPECTED_GENERATED_TEXT = "My favourite condiment is iced tea. I love it with my breakfast, my lunch" - tokenizer = AutoTokenizer.from_pretrained(model_id) - generated_text = model.text_generation( - tokenizer=tokenizer, - prompt="My favourite condiment is ", - max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), - ) - self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) - - @slow - @pytest.mark.run_slow - def test_gemma2_text_generation_with_xnnpack(self): - # model_id = "google/gemma-2-2b" - model_id = "unsloth/gemma-2-2b-it" - model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path=model_id, - export=True, - task="text-generation", - recipe="xnnpack", - ) - self.assertIsInstance(model, ExecuTorchModelForCausalLM) - self.assertIsInstance(model.model, ExecuTorchModule) - - EXPECTED_GENERATED_TEXT = ( - "Hello I am doing a project for my school and I need to make sure it is a great to be creative and I can!" - ) - tokenizer = AutoTokenizer.from_pretrained(model_id) - generated_text = model.text_generation( - tokenizer=tokenizer, - prompt="Hello I am doing a project for my school", - max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), - ) - self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) - - @slow - @pytest.mark.run_slow - def test_gemma_text_generation_with_xnnpack(self): - # model_id = "google/gemma-2b" - model_id = "weqweasdas/RM-Gemma-2B" - model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path=model_id, - export=True, - task="text-generation", - recipe="xnnpack", - ) - self.assertIsInstance(model, ExecuTorchModelForCausalLM) - self.assertIsInstance(model.model, ExecuTorchModule) - - EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to write a report on the history of the United States." - tokenizer = AutoTokenizer.from_pretrained(model_id) - generated_text = model.text_generation( - tokenizer=tokenizer, - prompt="Hello I am doing a project for my school", - max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), - ) - self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) - - @slow - @pytest.mark.run_slow - def test_olmo_text_generation_with_xnnpack(self): - model_id = "allenai/OLMo-1B-hf" - model = ExecuTorchModelForCausalLM.from_pretrained( - model_name_or_path=model_id, - export=True, - task="text-generation", - recipe="xnnpack", - ) - self.assertIsInstance(model, ExecuTorchModelForCausalLM) - self.assertIsInstance(model.model, ExecuTorchModule) - - EXPECTED_GENERATED_TEXT = ( - "Simply put, the theory of relativity states that the speed of light is the same in all directions." - ) - tokenizer = AutoTokenizer.from_pretrained(model_id) - generated_text = model.text_generation( - tokenizer=tokenizer, - prompt="Simply put, the theory of relativity states that", - max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), - ) - self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/executorch/runtime/test_modeling_gemma.py b/tests/executorch/runtime/test_modeling_gemma.py new file mode 100644 index 00000000000..08f80d4e574 --- /dev/null +++ b/tests/executorch/runtime/test_modeling_gemma.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_gemma_text_generation_with_xnnpack(self): + # TODO: Swithc to use google/gemma-2b once https://github.com/huggingface/optimum/issues/2127 is fixed + # model_id = "google/gemma-2b" + model_id = "weqweasdas/RM-Gemma-2B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to write a report on the history of the United States." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/executorch/runtime/test_modeling_gemma2.py b/tests/executorch/runtime/test_modeling_gemma2.py new file mode 100644 index 00000000000..6878daa774f --- /dev/null +++ b/tests/executorch/runtime/test_modeling_gemma2.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_gemma2_text_generation_with_xnnpack(self): + # TODO: Swithc to use google/gemma-2-2b once https://github.com/huggingface/optimum/issues/2127 is fixed + # model_id = "google/gemma-2-2b" + model_id = "unsloth/gemma-2-2b-it" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Hello I am doing a project for my school and I need to make sure it is a great to be creative and I can!" + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/executorch/runtime/test_modeling_llama.py b/tests/executorch/runtime/test_modeling_llama.py new file mode 100644 index 00000000000..1834ee162d3 --- /dev/null +++ b/tests/executorch/runtime/test_modeling_llama.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_llama3_2_1b_text_generation_with_xnnpack(self): + # TODO: Swithc to use meta-llama/Llama-3.2-1B once https://github.com/huggingface/optimum/issues/2127 is fixed + # model_id = "lama/Llama-3.2-1B" + model_id = "NousResearch/Llama-3.2-1B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_llama3_2_3b_text_generation_with_xnnpack(self): + # TODO: Swithc to use meta-llama/Llama-3.2-3B once https://github.com/huggingface/optimum/issues/2127 is fixed + # model_id = "lama/Llama-3.2-3B" + model_id = "NousResearch/Hermes-3-Llama-3.2-3B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Simply put, the theory of relativity states that time is relative and can be affected " + "by an object's speed. This theory was developed by Albert Einstein in the early 20th " + "century. The theory has two parts" + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/executorch/runtime/test_modeling_olmo.py b/tests/executorch/runtime/test_modeling_olmo.py new file mode 100644 index 00000000000..65c3045ad86 --- /dev/null +++ b/tests/executorch/runtime/test_modeling_olmo.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_olmo_text_generation_with_xnnpack(self): + model_id = "allenai/OLMo-1B-hf" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Simply put, the theory of relativity states that the speed of light is the same in all directions." + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/executorch/runtime/test_modeling_qwen2.py b/tests/executorch/runtime/test_modeling_qwen2.py new file mode 100644 index 00000000000..d80a286b72d --- /dev/null +++ b/tests/executorch/runtime/test_modeling_qwen2.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_text_generation_with_xnnpack(self): + model_id = "Qwen/Qwen2.5-0.5B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "My favourite condiment is iced tea. I love it with my breakfast, my lunch" + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="My favourite condiment is ", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) From 9063a90f2053d157ad0e0759edba414c2a5362db Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 10:15:50 +0100 Subject: [PATCH 13/30] Styling --- tests/executorch/runtime/test_modeling.py | 1 - tests/executorch/runtime/test_modeling_gemma.py | 2 -- tests/executorch/runtime/test_modeling_gemma2.py | 2 -- tests/executorch/runtime/test_modeling_llama.py | 2 -- tests/executorch/runtime/test_modeling_olmo.py | 2 -- tests/executorch/runtime/test_modeling_qwen2.py | 2 -- 6 files changed, 11 deletions(-) diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py index 6593da7a8c7..c97b461403c 100644 --- a/tests/executorch/runtime/test_modeling.py +++ b/tests/executorch/runtime/test_modeling.py @@ -19,7 +19,6 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule -from transformers import AutoTokenizer from transformers.testing_utils import ( slow, ) diff --git a/tests/executorch/runtime/test_modeling_gemma.py b/tests/executorch/runtime/test_modeling_gemma.py index 08f80d4e574..d54f0a6767e 100644 --- a/tests/executorch/runtime/test_modeling_gemma.py +++ b/tests/executorch/runtime/test_modeling_gemma.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile import unittest import pytest diff --git a/tests/executorch/runtime/test_modeling_gemma2.py b/tests/executorch/runtime/test_modeling_gemma2.py index 6878daa774f..b695da6d65b 100644 --- a/tests/executorch/runtime/test_modeling_gemma2.py +++ b/tests/executorch/runtime/test_modeling_gemma2.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile import unittest import pytest diff --git a/tests/executorch/runtime/test_modeling_llama.py b/tests/executorch/runtime/test_modeling_llama.py index 1834ee162d3..e91a96fd914 100644 --- a/tests/executorch/runtime/test_modeling_llama.py +++ b/tests/executorch/runtime/test_modeling_llama.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile import unittest import pytest diff --git a/tests/executorch/runtime/test_modeling_olmo.py b/tests/executorch/runtime/test_modeling_olmo.py index 65c3045ad86..aa57496f291 100644 --- a/tests/executorch/runtime/test_modeling_olmo.py +++ b/tests/executorch/runtime/test_modeling_olmo.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile import unittest import pytest diff --git a/tests/executorch/runtime/test_modeling_qwen2.py b/tests/executorch/runtime/test_modeling_qwen2.py index d80a286b72d..ef624a784ea 100644 --- a/tests/executorch/runtime/test_modeling_qwen2.py +++ b/tests/executorch/runtime/test_modeling_qwen2.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile import unittest import pytest From 09965ae097bf07f1f3b693be6053ef10d2d46f23 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 11:35:03 +0100 Subject: [PATCH 14/30] Fix doc --- docs/source/_toctree.yml | 17 +++++++++++++++++ optimum/exporters/__init__.py | 1 + optimum/exporters/executorch/__init__.py | 7 +++++-- .../exporters/executorch/recipes/__init__.py | 2 ++ optimum/exporters/executorch/tasks/__init__.py | 2 ++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 8444da1b9a9..dc69564b045 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -81,6 +81,23 @@ title: Reference isExpanded: false title: "ONNX" + - sections: + - local: exporters/executorch/overview + title: Overview + - sections: + - local: exporters/executorch/usage_guides/export_a_model + title: Export a model to ExecuTorch + - local: exporters/executorch/usage_guides/contribute + title: Add support for exporting an architecture to ExecuTorch + title: How-to guides + - sections: + - local: exporters/executorch/package_reference/configuration + title: ExecuTorch configurations + - local: exporters/executorch/package_reference/export + title: Export functions + title: Reference + isExpanded: false + title: "ExecuTorch" - sections: - local: exporters/tflite/overview title: Overview diff --git a/optimum/exporters/__init__.py b/optimum/exporters/__init__.py index eef17dac7f7..7b08812a569 100644 --- a/optimum/exporters/__init__.py +++ b/optimum/exporters/__init__.py @@ -13,4 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. from . import onnx # noqa +from . import executorch # noqa from .tasks import TasksManager # noqa diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py index 67590723703..3409e69fcfb 100644 --- a/optimum/exporters/executorch/__init__.py +++ b/optimum/exporters/executorch/__init__.py @@ -27,8 +27,11 @@ "discover_tasks", "register_task", ], - "xnnpack": [ - "export_to_executorch_with_xnnpack", + "tasks": [ + "causal_lm", + ], + "recipes": [ + "xnnpack", ], "__main__": ["main_export"], } diff --git a/optimum/exporters/executorch/recipes/__init__.py b/optimum/exporters/executorch/recipes/__init__.py index 30466c2d1a1..a2e21cf3970 100644 --- a/optimum/exporters/executorch/recipes/__init__.py +++ b/optimum/exporters/executorch/recipes/__init__.py @@ -9,3 +9,5 @@ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. + +from . import xnnpack diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py index 30466c2d1a1..754a8241ca3 100644 --- a/optimum/exporters/executorch/tasks/__init__.py +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -9,3 +9,5 @@ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. + +from . import causal_lm From 8f60d665d8659633156ca483e7bc5664085ef340 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 11:37:48 +0100 Subject: [PATCH 15/30] Disable LLama 3b since it ooms on GH instances --- tests/executorch/runtime/test_modeling_gemma.py | 2 +- tests/executorch/runtime/test_modeling_gemma2.py | 2 +- tests/executorch/runtime/test_modeling_llama.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/executorch/runtime/test_modeling_gemma.py b/tests/executorch/runtime/test_modeling_gemma.py index d54f0a6767e..0e4238bf8ee 100644 --- a/tests/executorch/runtime/test_modeling_gemma.py +++ b/tests/executorch/runtime/test_modeling_gemma.py @@ -32,7 +32,7 @@ def __init__(self, *args, **kwargs): @slow @pytest.mark.run_slow def test_gemma_text_generation_with_xnnpack(self): - # TODO: Swithc to use google/gemma-2b once https://github.com/huggingface/optimum/issues/2127 is fixed + # TODO: Switch to use google/gemma-2b once https://github.com/huggingface/optimum/issues/2127 is fixed # model_id = "google/gemma-2b" model_id = "weqweasdas/RM-Gemma-2B" model = ExecuTorchModelForCausalLM.from_pretrained( diff --git a/tests/executorch/runtime/test_modeling_gemma2.py b/tests/executorch/runtime/test_modeling_gemma2.py index b695da6d65b..22fe4ab60d7 100644 --- a/tests/executorch/runtime/test_modeling_gemma2.py +++ b/tests/executorch/runtime/test_modeling_gemma2.py @@ -32,7 +32,7 @@ def __init__(self, *args, **kwargs): @slow @pytest.mark.run_slow def test_gemma2_text_generation_with_xnnpack(self): - # TODO: Swithc to use google/gemma-2-2b once https://github.com/huggingface/optimum/issues/2127 is fixed + # TODO: Switch to use google/gemma-2-2b once https://github.com/huggingface/optimum/issues/2127 is fixed # model_id = "google/gemma-2-2b" model_id = "unsloth/gemma-2-2b-it" model = ExecuTorchModelForCausalLM.from_pretrained( diff --git a/tests/executorch/runtime/test_modeling_llama.py b/tests/executorch/runtime/test_modeling_llama.py index e91a96fd914..fb08a5615a5 100644 --- a/tests/executorch/runtime/test_modeling_llama.py +++ b/tests/executorch/runtime/test_modeling_llama.py @@ -32,7 +32,7 @@ def __init__(self, *args, **kwargs): @slow @pytest.mark.run_slow def test_llama3_2_1b_text_generation_with_xnnpack(self): - # TODO: Swithc to use meta-llama/Llama-3.2-1B once https://github.com/huggingface/optimum/issues/2127 is fixed + # TODO: Switch to use meta-llama/Llama-3.2-1B once https://github.com/huggingface/optimum/issues/2127 is fixed # model_id = "lama/Llama-3.2-1B" model_id = "NousResearch/Llama-3.2-1B" model = ExecuTorchModelForCausalLM.from_pretrained( @@ -55,8 +55,9 @@ def test_llama3_2_1b_text_generation_with_xnnpack(self): @slow @pytest.mark.run_slow + @pytest.mark.skip(reason="OOMs with macos-15 CI instances on GH.") def test_llama3_2_3b_text_generation_with_xnnpack(self): - # TODO: Swithc to use meta-llama/Llama-3.2-3B once https://github.com/huggingface/optimum/issues/2127 is fixed + # TODO: Switch to use meta-llama/Llama-3.2-3B once https://github.com/huggingface/optimum/issues/2127 is fixed # model_id = "lama/Llama-3.2-3B" model_id = "NousResearch/Hermes-3-Llama-3.2-3B" model = ExecuTorchModelForCausalLM.from_pretrained( From cbc7c43665b7a222bc83b1ca8d82fbb659e02652 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 12:41:03 +0100 Subject: [PATCH 16/30] Updating torch for executorch in the doc building PR --- .github/workflows/build_main_documentation.yml | 1 + .github/workflows/build_pr_documentation.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/build_main_documentation.yml b/.github/workflows/build_main_documentation.yml index d38274f320a..79f5485d05b 100644 --- a/.github/workflows/build_main_documentation.yml +++ b/.github/workflows/build_main_documentation.yml @@ -180,6 +180,7 @@ jobs: mkdir -p optimum-doc-build/optimum && cd optimum-doc-build/optimum wget https://huggingface.co/datasets/hf-doc-build/doc-build/raw/main/optimum/_versions.yml cd ../.. + pip install torch -U make doc BUILD_DIR=optimum-doc-build VERSION=${{ env.VERSION }} COMMIT_SHA_OPTIMUM=${{ env.VERSION }} cd .. diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 6eb09aff304..82311ea76c6 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -114,6 +114,7 @@ jobs: run: | sudo docker system prune -a -f cd optimum + pip install torch -U make doc BUILD_DIR=optimum-doc-build VERSION=pr_$PR_NUMBER COMMIT_SHA_OPTIMUM=$COMMIT_SHA CLONE_URL=$PR_CLONE_URL cd .. From 6242721107a85818aa37b0d65fa42384ddc2ad59 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 22:07:29 +0100 Subject: [PATCH 17/30] Test doc build --- .../workflows/build_main_documentation.yml | 1 - .github/workflows/build_pr_documentation.yml | 79 +++++++++---------- docs/Dockerfile | 2 +- 3 files changed, 40 insertions(+), 42 deletions(-) diff --git a/.github/workflows/build_main_documentation.yml b/.github/workflows/build_main_documentation.yml index 79f5485d05b..d38274f320a 100644 --- a/.github/workflows/build_main_documentation.yml +++ b/.github/workflows/build_main_documentation.yml @@ -180,7 +180,6 @@ jobs: mkdir -p optimum-doc-build/optimum && cd optimum-doc-build/optimum wget https://huggingface.co/datasets/hf-doc-build/doc-build/raw/main/optimum/_versions.yml cd ../.. - pip install torch -U make doc BUILD_DIR=optimum-doc-build VERSION=${{ env.VERSION }} COMMIT_SHA_OPTIMUM=${{ env.VERSION }} cd .. diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 82311ea76c6..4de5ff08a56 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -70,51 +70,50 @@ jobs: pip install black cd .. - - name: Make Habana documentation - run: | - sudo docker system prune -a -f - cd optimum-habana - make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER - sudo mv habana-doc-build ../optimum - cd .. - - - name: Make Intel documentation - run: | - sudo docker system prune -a -f - cd optimum-intel - make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER - sudo mv intel-doc-build ../optimum - cd .. - - # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public - - name: Make Furiosa documentation - run: | - echo "For PRs we don't build Furiosa doc" - - - name: Make AMD documentation - run: | - sudo docker system prune -a -f - cd optimum-amd - make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER - sudo mv amd-doc-build ../optimum - cd .. - - - name: Make TPU documentation - run: | - sudo docker system prune -a -f - source venv-doc/bin/activate - cd optimum-tpu - pip install -U pip - pip install . -f https://storage.googleapis.com/libtpu-releases/index.html - doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean - mv tpu-doc-build ../optimum - cd .. + # - name: Make Habana documentation + # run: | + # sudo docker system prune -a -f + # cd optimum-habana + # make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER + # sudo mv habana-doc-build ../optimum + # cd .. + + # - name: Make Intel documentation + # run: | + # sudo docker system prune -a -f + # cd optimum-intel + # make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER + # sudo mv intel-doc-build ../optimum + # cd .. + + # # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public + # - name: Make Furiosa documentation + # run: | + # echo "For PRs we don't build Furiosa doc" + + # - name: Make AMD documentation + # run: | + # sudo docker system prune -a -f + # cd optimum-amd + # make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER + # sudo mv amd-doc-build ../optimum + # cd .. + + # - name: Make TPU documentation + # run: | + # sudo docker system prune -a -f + # source venv-doc/bin/activate + # cd optimum-tpu + # pip install -U pip + # pip install . -f https://storage.googleapis.com/libtpu-releases/index.html + # doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean + # mv tpu-doc-build ../optimum + # cd .. - name: Make Optimum documentation run: | sudo docker system prune -a -f cd optimum - pip install torch -U make doc BUILD_DIR=optimum-doc-build VERSION=pr_$PR_NUMBER COMMIT_SHA_OPTIMUM=$COMMIT_SHA CLONE_URL=$PR_CLONE_URL cd .. diff --git a/docs/Dockerfile b/docs/Dockerfile index d76dc50c556..9e4a758ddfc 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -8,4 +8,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git RUN git clone $clone_url && cd optimum && git checkout $commit_sha -RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers] +RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,exporters-executorch,doc-build,diffusers] From e4eccbe8449d2df83905fbf9cc334756ee270b60 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 22:15:40 +0100 Subject: [PATCH 18/30] Test doc build --- docs/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/Dockerfile b/docs/Dockerfile index 9e4a758ddfc..d74ca571f1a 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -1,4 +1,5 @@ -FROM nikolaik/python-nodejs:python3.9-nodejs18 +# FROM nikolaik/python-nodejs:python3.9-nodejs18 +FROM nikolaik/python-nodejs:python3.11-nodejs23 ARG commit_sha ARG clone_url From e855005447d12f4567f8a14c73647b2b5464b717 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 22:20:25 +0100 Subject: [PATCH 19/30] Test doc build --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bb5bcc11d43..caf0ae452f6 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ "h5py", "numpy<1.24.0", "datasets<=2.16", - "transformers>=4.36,<4.38", + "transformers>=4.36", ], "exporters-executorch": [ "executorch>=0.4.0", From 5b32dea40fd59dd1b1843c604c694b4383748eae Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 18 Dec 2024 22:29:13 +0100 Subject: [PATCH 20/30] Test doc build --- docs/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Dockerfile b/docs/Dockerfile index d74ca571f1a..01d02908ebe 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -9,4 +9,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git RUN git clone $clone_url && cd optimum && git checkout $commit_sha -RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,exporters-executorch,doc-build,diffusers] +RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers] From 093e95a86e8ea322f963ba4d14a04ca55604d376 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 10:32:01 +0100 Subject: [PATCH 21/30] Test doc build --- docs/Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Dockerfile b/docs/Dockerfile index 01d02908ebe..6ffef54466f 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -10,3 +10,4 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc RUN git clone $clone_url && cd optimum && git checkout $commit_sha RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers] +RUN python3 -m pip install ./optimum[exporters-executorch] From 2c748dff141a989b099cb1bd36210c5f05f41a6d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 10:40:35 +0100 Subject: [PATCH 22/30] Test doc build --- docs/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Dockerfile b/docs/Dockerfile index 6ffef54466f..be3cb05fdb5 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -1,5 +1,5 @@ -# FROM nikolaik/python-nodejs:python3.9-nodejs18 -FROM nikolaik/python-nodejs:python3.11-nodejs23 +FROM nikolaik/python-nodejs:python3.9-nodejs18 +# FROM nikolaik/python-nodejs:python3.11-nodejs23 ARG commit_sha ARG clone_url From 9309d77623af4dbb23800ebcb3186b93b976b280 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 10:46:15 +0100 Subject: [PATCH 23/30] Build doc --- .github/workflows/build_pr_documentation.yml | 78 ++++++++++---------- docs/Dockerfile | 3 +- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 4de5ff08a56..6eb09aff304 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -70,45 +70,45 @@ jobs: pip install black cd .. - # - name: Make Habana documentation - # run: | - # sudo docker system prune -a -f - # cd optimum-habana - # make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER - # sudo mv habana-doc-build ../optimum - # cd .. - - # - name: Make Intel documentation - # run: | - # sudo docker system prune -a -f - # cd optimum-intel - # make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER - # sudo mv intel-doc-build ../optimum - # cd .. - - # # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public - # - name: Make Furiosa documentation - # run: | - # echo "For PRs we don't build Furiosa doc" - - # - name: Make AMD documentation - # run: | - # sudo docker system prune -a -f - # cd optimum-amd - # make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER - # sudo mv amd-doc-build ../optimum - # cd .. - - # - name: Make TPU documentation - # run: | - # sudo docker system prune -a -f - # source venv-doc/bin/activate - # cd optimum-tpu - # pip install -U pip - # pip install . -f https://storage.googleapis.com/libtpu-releases/index.html - # doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean - # mv tpu-doc-build ../optimum - # cd .. + - name: Make Habana documentation + run: | + sudo docker system prune -a -f + cd optimum-habana + make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER + sudo mv habana-doc-build ../optimum + cd .. + + - name: Make Intel documentation + run: | + sudo docker system prune -a -f + cd optimum-intel + make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER + sudo mv intel-doc-build ../optimum + cd .. + + # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public + - name: Make Furiosa documentation + run: | + echo "For PRs we don't build Furiosa doc" + + - name: Make AMD documentation + run: | + sudo docker system prune -a -f + cd optimum-amd + make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER + sudo mv amd-doc-build ../optimum + cd .. + + - name: Make TPU documentation + run: | + sudo docker system prune -a -f + source venv-doc/bin/activate + cd optimum-tpu + pip install -U pip + pip install . -f https://storage.googleapis.com/libtpu-releases/index.html + doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean + mv tpu-doc-build ../optimum + cd .. - name: Make Optimum documentation run: | diff --git a/docs/Dockerfile b/docs/Dockerfile index be3cb05fdb5..d8bd060a46c 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -1,5 +1,4 @@ -FROM nikolaik/python-nodejs:python3.9-nodejs18 -# FROM nikolaik/python-nodejs:python3.11-nodejs23 +FROM nikolaik/python-nodejs:python3.11-nodejs23 ARG commit_sha ARG clone_url From b447d1a013f2f862eace34e5613552853402d2d8 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 12:10:33 +0100 Subject: [PATCH 24/30] Build doc --- .github/workflows/build_pr_documentation.yml | 78 ++++++++++---------- docs/Dockerfile | 3 +- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 6eb09aff304..4de5ff08a56 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -70,45 +70,45 @@ jobs: pip install black cd .. - - name: Make Habana documentation - run: | - sudo docker system prune -a -f - cd optimum-habana - make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER - sudo mv habana-doc-build ../optimum - cd .. - - - name: Make Intel documentation - run: | - sudo docker system prune -a -f - cd optimum-intel - make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER - sudo mv intel-doc-build ../optimum - cd .. - - # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public - - name: Make Furiosa documentation - run: | - echo "For PRs we don't build Furiosa doc" - - - name: Make AMD documentation - run: | - sudo docker system prune -a -f - cd optimum-amd - make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER - sudo mv amd-doc-build ../optimum - cd .. - - - name: Make TPU documentation - run: | - sudo docker system prune -a -f - source venv-doc/bin/activate - cd optimum-tpu - pip install -U pip - pip install . -f https://storage.googleapis.com/libtpu-releases/index.html - doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean - mv tpu-doc-build ../optimum - cd .. + # - name: Make Habana documentation + # run: | + # sudo docker system prune -a -f + # cd optimum-habana + # make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER + # sudo mv habana-doc-build ../optimum + # cd .. + + # - name: Make Intel documentation + # run: | + # sudo docker system prune -a -f + # cd optimum-intel + # make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER + # sudo mv intel-doc-build ../optimum + # cd .. + + # # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public + # - name: Make Furiosa documentation + # run: | + # echo "For PRs we don't build Furiosa doc" + + # - name: Make AMD documentation + # run: | + # sudo docker system prune -a -f + # cd optimum-amd + # make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER + # sudo mv amd-doc-build ../optimum + # cd .. + + # - name: Make TPU documentation + # run: | + # sudo docker system prune -a -f + # source venv-doc/bin/activate + # cd optimum-tpu + # pip install -U pip + # pip install . -f https://storage.googleapis.com/libtpu-releases/index.html + # doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean + # mv tpu-doc-build ../optimum + # cd .. - name: Make Optimum documentation run: | diff --git a/docs/Dockerfile b/docs/Dockerfile index d8bd060a46c..13ad3258b3d 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -8,5 +8,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git RUN git clone $clone_url && cd optimum && git checkout $commit_sha -RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers] -RUN python3 -m pip install ./optimum[exporters-executorch] +RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,exporters-executorch,doc-build,diffusers] --index-url https://download.pytorch.org/whl/cpu From b8701ff913213d7f19ec1fc86c28651b3fef4f0b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 12:14:53 +0100 Subject: [PATCH 25/30] Build doc --- docs/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/Dockerfile b/docs/Dockerfile index 13ad3258b3d..4eb2ad476c8 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -8,4 +8,5 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git RUN git clone $clone_url && cd optimum && git checkout $commit_sha -RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,exporters-executorch,doc-build,diffusers] --index-url https://download.pytorch.org/whl/cpu +RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers] +RUN python3 -m pip install ./optimum[exporters-executorch] --index-url https://download.pytorch.org/whl/cpu From 3041f169b39b7e8871208963f933ceaab22a3014 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 12:22:28 +0100 Subject: [PATCH 26/30] Build doc --- docs/Dockerfile | 3 +-- setup.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/docs/Dockerfile b/docs/Dockerfile index 4eb2ad476c8..f168d7ca8a6 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -8,5 +8,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git RUN git clone $clone_url && cd optimum && git checkout $commit_sha -RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers] -RUN python3 -m pip install ./optimum[exporters-executorch] --index-url https://download.pytorch.org/whl/cpu +RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,exporters-executorch,doc-build,diffusers] diff --git a/setup.py b/setup.py index caf0ae452f6..f047b7f4e08 100644 --- a/setup.py +++ b/setup.py @@ -76,14 +76,14 @@ ], "exporters-tf": [ "tensorflow>=2.4,<=2.12.1", - "tf2onnx", - "onnx", - "onnxruntime", - "timm", - "h5py", - "numpy<1.24.0", - "datasets<=2.16", - "transformers>=4.36", + # "tf2onnx", + # "onnx", + # "onnxruntime", + # "timm", + # "h5py", + # "numpy<1.24.0", + # "datasets<=2.16", + # "transformers>=4.36", ], "exporters-executorch": [ "executorch>=0.4.0", From 045df3795e6aa825894da2412ef1e1ee1ce5ca74 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 19 Dec 2024 12:51:38 +0100 Subject: [PATCH 27/30] Build doc --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index f047b7f4e08..4ec891d7c19 100644 --- a/setup.py +++ b/setup.py @@ -75,15 +75,15 @@ "transformers>=4.36,<4.47.0", ], "exporters-tf": [ - "tensorflow>=2.4,<=2.12.1", - # "tf2onnx", - # "onnx", - # "onnxruntime", - # "timm", - # "h5py", - # "numpy<1.24.0", - # "datasets<=2.16", - # "transformers>=4.36", + # "tensorflow>=2.4,<=2.12.1", + "tf2onnx", + "onnx", + "onnxruntime", + "timm", + "h5py", + "numpy<1.24.0", + "datasets<=2.16", + "transformers>=4.36", ], "exporters-executorch": [ "executorch>=0.4.0", From 224101bdcadc384504aa17b4c9d731f42cac3079 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 20 Dec 2024 09:55:33 +0100 Subject: [PATCH 28/30] Try with the full doc --- .github/workflows/build_pr_documentation.yml | 78 ++++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 4de5ff08a56..6eb09aff304 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -70,45 +70,45 @@ jobs: pip install black cd .. - # - name: Make Habana documentation - # run: | - # sudo docker system prune -a -f - # cd optimum-habana - # make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER - # sudo mv habana-doc-build ../optimum - # cd .. - - # - name: Make Intel documentation - # run: | - # sudo docker system prune -a -f - # cd optimum-intel - # make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER - # sudo mv intel-doc-build ../optimum - # cd .. - - # # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public - # - name: Make Furiosa documentation - # run: | - # echo "For PRs we don't build Furiosa doc" - - # - name: Make AMD documentation - # run: | - # sudo docker system prune -a -f - # cd optimum-amd - # make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER - # sudo mv amd-doc-build ../optimum - # cd .. - - # - name: Make TPU documentation - # run: | - # sudo docker system prune -a -f - # source venv-doc/bin/activate - # cd optimum-tpu - # pip install -U pip - # pip install . -f https://storage.googleapis.com/libtpu-releases/index.html - # doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean - # mv tpu-doc-build ../optimum - # cd .. + - name: Make Habana documentation + run: | + sudo docker system prune -a -f + cd optimum-habana + make doc BUILD_DIR=habana-doc-build VERSION=pr_$PR_NUMBER + sudo mv habana-doc-build ../optimum + cd .. + + - name: Make Intel documentation + run: | + sudo docker system prune -a -f + cd optimum-intel + make doc BUILD_DIR=intel-doc-build VERSION=pr_$PR_NUMBER + sudo mv intel-doc-build ../optimum + cd .. + + # TODO: enable Furiosa doc build in PRs once archive.furiosa.ai is public + - name: Make Furiosa documentation + run: | + echo "For PRs we don't build Furiosa doc" + + - name: Make AMD documentation + run: | + sudo docker system prune -a -f + cd optimum-amd + make doc BUILD_DIR=amd-doc-build VERSION=pr_$PR_NUMBER + sudo mv amd-doc-build ../optimum + cd .. + + - name: Make TPU documentation + run: | + sudo docker system prune -a -f + source venv-doc/bin/activate + cd optimum-tpu + pip install -U pip + pip install . -f https://storage.googleapis.com/libtpu-releases/index.html + doc-builder build optimum.tpu docs/source/ --build_dir tpu-doc-build --version pr_$PR_NUMBER --version_tag_suffix "" --html --clean + mv tpu-doc-build ../optimum + cd .. - name: Make Optimum documentation run: | From b9ce05da6d3cccac1a7ec10ed702a27f74c6a31c Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 20 Dec 2024 10:21:45 +0100 Subject: [PATCH 29/30] Try with the full doc --- docs/Dockerfile | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Dockerfile b/docs/Dockerfile index f168d7ca8a6..5181177f0db 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -8,4 +8,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git RUN git clone $clone_url && cd optimum && git checkout $commit_sha -RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,exporters-executorch,doc-build,diffusers] +RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-executorch,doc-build,diffusers] diff --git a/setup.py b/setup.py index 4ec891d7c19..caf0ae452f6 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ "transformers>=4.36,<4.47.0", ], "exporters-tf": [ - # "tensorflow>=2.4,<=2.12.1", + "tensorflow>=2.4,<=2.12.1", "tf2onnx", "onnx", "onnxruntime", From e345f0745b986561ae4998b9407aff93f93600e1 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 20 Dec 2024 11:56:56 +0100 Subject: [PATCH 30/30] Restore setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index caf0ae452f6..bb5bcc11d43 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ "h5py", "numpy<1.24.0", "datasets<=2.16", - "transformers>=4.36", + "transformers>=4.36,<4.38", ], "exporters-executorch": [ "executorch>=0.4.0",