Export to ExecuTorch: Initial Integration
Guang Yang committed on Nov 14, 2024
1 parent 7e8d857 · commit 757f152
Showing 16 changed files with 716 additions and 1 deletion.
@@ -0,0 +1,53 @@
"""Defines the command line for the export with ExecuTorch."""

from pathlib import Path
from typing import TYPE_CHECKING

from ...exporters import TasksManager
from ..base import BaseOptimumCLICommand


if TYPE_CHECKING:
    from argparse import ArgumentParser


def parse_args_executorch(parser):
    required_group = parser.add_argument_group("Required arguments")
    required_group.add_argument(
        "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
    )
    required_group.add_argument(
        "--output_dir", type=Path, help="Path indicating the directory where to store the generated ExecuTorch model."
    )

    optional_group = parser.add_argument_group("Optional arguments")
    optional_group.add_argument(
        "--task",
        default="auto",
        help=(
            "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
            f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
        ),
    )
    optional_group.add_argument(
        "--recipe",
        type=str,
        default="xnnpack",
        help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".',
    )


class ExecuTorchExportCommand(BaseOptimumCLICommand):
    @staticmethod
    def parse_args(parser: "ArgumentParser"):
        return parse_args_executorch(parser)

    def run(self):
        from ...exporters.executorch import main_export

        main_export(
            model_name_or_path=self.args.model,
            task=self.args.task,
            recipe=self.args.recipe,
            output_dir=self.args.output_dir,
        )
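If this command is wired up as the `executorch` subcommand of `optimum-cli export` (the registration itself is not part of this file, so the exact invocation is an assumption), exporting a decoder model would look roughly like:

optimum-cli export executorch --model meta-llama/Llama-3.2-1B --task text-generation --recipe xnnpack --output_dir ./llama_pte

Here `meta-llama/Llama-3.2-1B` and `./llama_pte` are placeholder values; any Hub model ID or local path accepted by `--model` works the same way.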
@@ -0,0 +1,16 @@
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
    "modeling_executorch": [
        "ExecuTorchModelForCausalLM",
    ],
}

if TYPE_CHECKING:
    from .modeling_executorch import ExecuTorchModelForCausalLM
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
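Because the module is registered through `_LazyModule`, importing the package itself stays cheap; `modeling_executorch` (and its torch/ExecuTorch dependencies) is only loaded on first access. A minimal sketch, assuming the package ends up importable as `optimum.executorch` (the top-level path is not shown in this diff):

import optimum.executorch as et_models  # assumed package path; this import stays lightweight

# Attribute access triggers _LazyModule to import modeling_executorch on demand.
model_cls = et_models.ExecuTorchModelForCausalLM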
@@ -0,0 +1,234 @@
"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers."""

import logging
import os
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import EntryNotFoundError
from transformers import (
    AutoConfig,
    AutoModel,
    GenerationMixin,
    AutoModelForCausalLM,
    GenerationConfig,
)
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
from transformers.modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
    CausalLMOutputWithPast,
    ModelOutput,
)

from ..exporters import TasksManager
from ..exporters.executorch import main_export
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel

if TYPE_CHECKING:
    from transformers import PretrainedConfig


logger = logging.getLogger(__name__)


class ExecuTorchModelForCausalLM(OptimizedModel):
    """
    ExecuTorch model with a causal language modeling head for ExecuTorch Runtime inference.
    """

    auto_model_class = AutoModelForCausalLM

    def __init__(
        self,
        model: "ExecuTorchModule",
        config: "PretrainedConfig",
    ):
        super().__init__(model, config)
        self.et_model = model
print(f"DEBUG all static methods: {self.et_model.method_names()}") | ||
        self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0]
        self.max_seq_len = self.et_model.run_method("get_max_seq_len")[0]
        self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0]
        self.dtype = self.et_model.run_method("get_dtype")[0]
        self.bos_token_id = self.et_model.run_method("get_bos_id")[0]
        self.eos_token_id = self.et_model.run_method("get_eos_id")[0]
        self.vocab_size = self.et_model.run_method("get_vocab_size")[0]

    def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        return self.et_model.forward((input_ids, cache_position))[0]

    @classmethod
    def from_pretrained(
        cls,
        model_dir_path: Union[str, Path],
        task: str,
        recipe: str,
        config: "PretrainedConfig" = None,
        use_auth_token: Optional[Union[bool, str]] = None,
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        subfolder: str = "",
        local_files_only: bool = False,
    ) -> "ExecuTorchModelForCausalLM":
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
            token = use_auth_token

        full_path = os.path.join(f"{model_dir_path}", "model.pte")
        model = _load_for_executorch(full_path)
        logging.debug(f"{model.method_meta('forward')}")
        return cls(
            model=model,
            config=config,
        )

    def _save_pretrained(self, save_directory):
        """
        Saves a model weights into a directory, so that it can be re-loaded using the
        [`from_pretrained`] class method.
        """
        raise NotImplementedError

    @classmethod
    def _export(
        cls,
        model_id: str,
        task: str,
        recipe: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
    ):
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
            token = use_auth_token

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        # Export to ExecuTorch and save the pte file to the temporary directory
        main_export(
            model_name_or_path=model_id,
            output=save_dir_path,
            task=task,
            recipe=recipe,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
        )

        return cls._from_pretrained(
            model_dir_path=save_dir_path,
            task=task,
            recipe=recipe,
            config=config,
            use_auth_token=use_auth_token,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
            local_files_only=local_files_only,
            force_download=force_download,
        )

    def generate(
        self,
        prompt_tokens: List[int],
        echo: bool = False,
        pos_base: int = 0,
    ) -> List[int]:
        self.device = torch.device("cpu")
        self.max_seq_len = 256
        generated_tokens = []

        # prefill
        for i, prompt_token in enumerate(prompt_tokens):
            logits = self.forward(
                input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
                cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
            )

        next_token = torch.argmax(logits, dim=-1).item()
        generated_tokens = prompt_tokens + [next_token]

        while len(generated_tokens) < self.max_seq_len:
            logits = self.forward(
                input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
                cache_position=torch.tensor(
                    [pos_base + len(generated_tokens) - 1],
                    dtype=torch.long,
                    device=self.device,
                ),
            )
            next_token = torch.argmax(logits, dim=-1).item()
            generated_tokens.append(next_token)
            if next_token == self.eos_token_id:
                break

        return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]
    def text_generation(
        self,
        tokenizer: "PreTrainedTokenizer",
        prompt: str,
        echo: bool = True,
    ) -> str:
        """
        Perform text completion for a prompt using the language model.

        Args:
            tokenizer (`PreTrainedTokenizer`): Tokenizer matching the exported model.
            prompt (str): Text prompt for completion.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to True.

        Returns:
            The generated text, decoded from the generated token ids.

        Note:
            This method generates a completion for the provided prompt using greedy (argmax) decoding.
        """
        self.tokenizer = tokenizer
        if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
            raise ValueError(
                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
            )
        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
            raise ValueError(
                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
            )

        prompt_tokens = self.tokenizer.encode(prompt)
        generated_tokens = self.generate(
            prompt_tokens=prompt_tokens,
            echo=echo,
        )
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
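Putting the pieces together, a minimal end-to-end sketch of running inference with this class. It assumes a directory that already contains a `model.pte` produced by the exporter, a matching config/tokenizer from the original checkpoint, and that the class is exposed as in the `__init__.py` above; the checkpoint name, paths, and import path are placeholders.

from transformers import AutoConfig, AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM  # assumed package path

checkpoint = "meta-llama/Llama-3.2-1B"  # hypothetical checkpoint
config = AutoConfig.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# "./llama_pte" is assumed to contain model.pte from a prior export.
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_dir_path="./llama_pte",
    task="text-generation",
    recipe="xnnpack",
    config=config,
)

# Greedy decoding through text_generation(); returns the decoded string.
print(model.text_generation(tokenizer=tokenizer, prompt="Hello, my name is"))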
@@ -0,0 +1,24 @@
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
    "convert": [
        "export_to_executorch",
    ],
    "__main__": ["main_export"],
}

if TYPE_CHECKING:
    from .__main__ import main_export
    from .convert import export_to_executorch
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
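For completeness, a programmatic sketch of the exposed entry point, mirroring how `ExecuTorchExportCommand.run()` calls it in the CLI file above. The keyword names are copied from that call site and the import path is assumed, so treat this as illustrative rather than the definitive `main_export` signature.

from optimum.exporters.executorch import main_export  # assumed import path

# Export a Hub checkpoint to ExecuTorch with the default XNNPACK recipe.
main_export(
    model_name_or_path="meta-llama/Llama-3.2-1B",  # hypothetical checkpoint
    task="text-generation",
    recipe="xnnpack",
    output_dir="./llama_pte",
)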