forked from dottxt-ai/outlines
Introduce mlx-lm model via outlines.models.mlxlm
Showing 7 changed files with 192 additions and 47 deletions.
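Before the file-by-file diff, a minimal usage sketch of the new entry point. It assumes a machine that can run MLX (Apple Silicon) with `mlx-lm` installed; the model repository is the one used by the `model_mlxlm` test fixture below, and the prompt is illustrative.

import outlines.generate as generate
import outlines.models as models

# Load a small quantized model through the new `outlines.models.mlxlm`
# entry point (same repo as the `model_mlxlm` fixture in the tests below).
model = models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")

# Unconstrained text generation, mirroring `test_generate_text`.
generator = generate.text(model)
print(generator("Write a haiku about the sea.", max_tokens=30))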
@@ -0,0 +1,122 @@
from typing import TYPE_CHECKING, Optional, Tuple

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper


class MLXLM:
    """
    Represents an `mlx_lm` model.

    Adapted from
    https://github.com/sacha-ichbiah/outlines-mlx/blob/main/outlinesmlx/models/mlx.py
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "TokenizerWrapper",
    ):
        self.model = model

        # mlx_lm's TokenizerWrapper bundles an HF tokenizer (`_tokenizer`)
        # with a `_detokenizer`; only the HF tokenizer is needed here.
        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)

    def forward(
        self,
        input_ids: "mx.array",
        attention_mask: "mx.array",
        past_key_values: "mx.array",
    ) -> Tuple["mx.array", Optional["mx.array"]]:
        """Compute a forward pass through the transformer model.

        Parameters
        ----------
        input_ids
            The input token ids. Must be one or two dimensional.
        attention_mask
            The attention mask. Must be one or two dimensional.
        past_key_values
            A tuple of tuples containing the cached key and value tensors
            for each attention head.

        Returns
        -------
        The computed logits and the new cached key and value tensors.
        """
        assert 0 < input_ids.ndim < 3

        # With a populated KV cache, only the most recent token needs to be
        # fed through the model.
        if past_key_values:
            input_ids = input_ids[..., -1][..., None]

        logits, kv_cache = self.model(
            input_ids,
            cache=past_key_values,
        )

        return logits, kv_cache

    def __call__(
        self,
        input_ids: "mx.array",
        attention_mask: "mx.array",
        past_key_values: Optional["mx.array"] = None,
    ) -> Tuple["mx.array", "mx.array"]:
        logits, kv_cache = self.forward(input_ids, None, past_key_values)
        # Keep only the logits at the last position; they score the next token.
        next_token_logits = logits[..., -1, :]

        return next_token_logits, kv_cache


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    # TODO: include these kwargs when mlx-lm has a new release
    # model_config: dict = {},
    # adapter_path: Optional[str] = None,
    # lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name
        The path or the Hugging Face repository to load the model from.
    tokenizer_config
        Configuration parameters specifically for the tokenizer. Defaults to
        an empty dictionary.
    model_config
        Configuration parameters specifically for the model. Defaults to an
        empty dictionary.
    adapter_path
        Path to the LoRA adapters. If provided, applies LoRA layers to the
        model. Default: ``None``.
    lazy
        If ``False``, evaluate the model parameters to make sure they are
        loaded in memory before returning; otherwise they are loaded when
        needed. Default: ``False``.

    Returns
    -------
    A `MLXLM` model instance.
    """
    try:
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        # model_config=model_config,
        # adapter_path=adapter_path,
        # lazy=lazy,
    )
    return MLXLM(model, tokenizer)
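To make the cache handling above concrete, here is a hedged sketch of how the wrapper's `__call__` could drive a greedy decoding loop by hand; `greedy_decode`, its prompt handling, and the fixed token budget are illustrative assumptions, not part of the commit.

import mlx.core as mx

def greedy_decode(model: MLXLM, prompt_ids: "mx.array", max_tokens: int = 20) -> "mx.array":
    """Illustrative greedy decoding against the `MLXLM` KV-cache interface."""
    token_ids = prompt_ids
    kv_cache = None
    for _ in range(max_tokens):
        # The first iteration processes the whole prompt; later iterations
        # pass the growing sequence, and `forward` slices off the last token
        # because the cache is populated.
        next_token_logits, kv_cache = model(token_ids, None, kv_cache)
        # Greedy choice: take the highest-scoring token and append it.
        next_token = mx.argmax(next_token_logits, axis=-1)[..., None]
        token_ids = mx.concatenate([token_ids, next_token.astype(token_ids.dtype)], axis=-1)
    return token_ids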
@@ -0,0 +1,55 @@
import re

import pytest

import outlines.generate as generate
import outlines.models as models


@pytest.fixture(scope="session")
def model_llamacpp(tmp_path_factory):
    return models.llamacpp(
        repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
        filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
    )


@pytest.fixture(scope="session")
def model_mlxlm(tmp_path_factory):
    return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")


@pytest.fixture(scope="session")
def model_transformers(tmp_path_factory):
    return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu")


@pytest.mark.parametrize(
    "model_fixture",
    ("model_llamacpp", "model_mlxlm", "model_transformers"),
)
def test_generate_text(request, model_fixture):
    model = request.getfixturevalue(model_fixture)
    generator = generate.text(model)
    res = generator("test", max_tokens=10)
    assert isinstance(res, str)


@pytest.mark.parametrize(
    "model_fixture",
    ("model_llamacpp", "model_mlxlm", "model_transformers"),
)
@pytest.mark.parametrize(
    "pattern",
    (
        "[0-9]",
        "abc*",
        "\\+?[1-9][0-9]{7,14}",
        r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
    ),
)
def test_generate_regex(request, model_fixture, pattern):
    model = request.getfixturevalue(model_fixture)
    # Constrain generation with the pattern so the output is guaranteed to
    # match it; unconstrained `generate.text` could not satisfy the
    # assertion below.
    generator = generate.regex(model, pattern)
    res = generator("foobarbaz", max_tokens=20)
    assert re.fullmatch(pattern, res) is not None, res
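For reference, the constrained path these tests exercise looks like this in user code; a hedged sketch, with the model repository taken from the `model_mlxlm` fixture and the phone-number pattern from the parametrization above.

import outlines.generate as generate
import outlines.models as models

model = models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")

# Generation is constrained token by token, so the output matches the regex.
generator = generate.regex(model, r"\+?[1-9][0-9]{7,14}")
phone_number = generator("A plausible phone number: ", max_tokens=20)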