Skip to content

Commit

Permalink
[LLAMA] Export Pathway + clean-up other imports/pathways (#279)
Browse files Browse the repository at this point in the history
* add llama export pathway

* update/clean-up paths for finetune and llama, fix import check, fix llama export script typo

* fix help description

* address PR comments
  • Loading branch information
dsikka authored Sep 15, 2023
1 parent 7d5b20d commit 4dc749d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 16 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _setup_install_requires() -> List:


def _setup_extras() -> Dict:
return {"dev": _dev_deps, "_nm_deps": _nm_deps, "llm": _llm_deps}
return {"dev": _dev_deps, "nm": _nm_deps, "llm": _llm_deps}


def _setup_entry_points() -> Dict:
Expand All @@ -86,7 +86,8 @@ def _setup_entry_points() -> Dict:
"sparsify.run=sparsify.cli.run:main",
"sparsify.login=sparsify.login:main",
"sparsify.check_environment=sparsify.check_environment.main:main",
"finetune=sparsify.auto.tasks.finetune.finetune:parse_args_and_run",
"sparsify.llm_finetune=sparsify.auto.tasks.finetune.finetune:parse_args_and_run", # noqa E501
"sparisfy.llama_export=sparsify.auto.tasks.transformers.llama:llama_export",
]
}

Expand Down
19 changes: 5 additions & 14 deletions src/sparsify/auto/tasks/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,8 @@
# flake8: noqa
# isort: skip_file


def _check_nm_install():
try:
from .runner import *
except ImportError as exception:
raise ImportError(
"Please install sparsify[nm] to use this pathway."
) from exception


_check_nm_install()

from .args import *
from .runner import *
try:
from .args import *
from .runner import *
except ImportError as exception:
raise ImportError("Please install sparsify[nm] to use this pathway.") from exception
87 changes: 87 additions & 0 deletions src/sparsify/auto/tasks/transformers/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import click
from sparseml.transformers.export import export as export_hook
from sparsify.auto.tasks.transformers import TransformersExportArgs


_LOGGER = logging.getLogger()
_LOGGER.setLevel(logging.INFO)

"""
Usage: sparisfy.llama_export [OPTIONS]
Exports a LLAMA model given a path to model directory, sequence length, and
name of the exported pathway.
Example: sparisfy.llama_export --model_path <MODEL_PATH>
--sequence_length <INT>
Output: Produces a deployment directory with the exported model
<ONNX_FILE_NAME>
Options:
--model_path TEXT Path to directory where model files for weights,
config, and tokenizer are stored
--sequence_length INTEGER Sequence length to use. [default: 2048]
--onnx_file_name TEXT Name of the exported model. [default:
model.onnx]
--help Show this message and exit. [default: False]
"""


@click.command(context_settings={"show_default": True})
@click.option(
"--model_path",
default=None,
type=str,
help="Path to directory where model files for weights, config, and "
"tokenizer are stored",
)
@click.option(
"--sequence_length",
default=2048,
type=int,
help="Sequence length to use.",
)
@click.option(
"--onnx_file_name",
default="model.onnx",
type=str,
help="Name of the exported model.",
)
def llama_export(model_path: str, sequence_length: int, onnx_file_name: str):
"""
Exports a LLAMA model given a path to model directory, sequence length, and name
of the exported pathway.
Example:
sparisfy.llama_export --model_path <MODEL_PATH> --sequence_length <INT>
Output:
Produces a deployment directory with the exported model <ONNX_FILE_NAME>
"""
export_args = TransformersExportArgs(
task="text-generation",
model_path=model_path,
sequence_length=sequence_length,
no_convert_qat=True,
onnx_file_name=onnx_file_name,
)
_LOGGER.info("Exporting LLAMA model")
export_hook(**export_args.dict())

0 comments on commit 4dc749d

Please sign in to comment.