Introduce mlx-lm model via outlines.models.mlxlm
lapp0 committed Jun 6, 2024
1 parent 8a6b7dc commit 346e383
Showing 7 changed files with 192 additions and 47 deletions.
22 changes: 11 additions & 11 deletions outlines/fsm/regex.py
@@ -414,7 +414,7 @@ def _walk_fsm(
     fsm_transitions: Dict[Tuple[int, int], int],
     fsm_initial: int,
     fsm_finals: Set[int],
-    token_trans_key_seq: Sequence[int],
+    token_transition_keys: Sequence[int],
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -424,7 +424,7 @@

     # Iterate over token transition key sequence. The transition key
     # sequence represents the FSM traversal rules of the tokens symbols.
-    for i, trans_key in enumerate(token_trans_key_seq):
+    for i, trans_key in enumerate(token_transition_keys):
         new_state = fsm_transitions.get((state, trans_key))
 
         if new_state is None:
@@ -448,7 +448,7 @@

 def walk_fsm(
     fsm: BetterFSM,
-    token_trans_key_seq: Sequence[int],
+    token_transition_keys: Sequence[int],
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -462,7 +462,7 @@

     # Iterate over token transition key sequence. The transition key
     # sequence represents the FSM traversal rules of the tokens symbols.
-    for i, trans_key in enumerate(token_trans_key_seq):
+    for i, trans_key in enumerate(token_transition_keys):
         new_state = fsm_transitions.get((state, trans_key))
 
         if new_state is None:
@@ -703,10 +703,10 @@ def get_token_transition_keys(
             alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
         )
 
-    tok_trans_array = np.empty(len(token_transition_keys), dtype=np.int64)
+    token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64)
     for j in range(len(token_transition_keys)):
-        tok_trans_array[j] = token_transition_keys[j]
-    return tok_trans_array
+        token_transition_keys_array[j] = token_transition_keys[j]
+    return token_transition_keys_array


@numba.njit(cache=True, nogil=True)
@@ -718,14 +718,14 @@ def get_vocabulary_transition_keys(
"""
Calculate the sequence transition keys for each token str within a vocabulary
"""
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq_array = get_token_transition_keys(
token_transition_keys = get_token_transition_keys(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
tokens_trans_keys.append(trans_key_seq_array)
vocab_transition_keys.append(token_transition_keys)

return tokens_trans_keys
return vocab_transition_keys


def create_fsm_index_end_to_end(
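The renames above are purely cosmetic; the traversal logic is unchanged. As a rough illustration of what `walk_fsm` does with a token's transition keys, here is a simplified sketch (plain Python rather than the numba-compiled version, with the `full_match` final-state check omitted; the toy transition table is hypothetical):

from typing import Dict, List, Sequence, Tuple

def walk_fsm_sketch(
    fsm_transitions: Dict[Tuple[int, int], int],
    token_transition_keys: Sequence[int],
    start_state: int,
) -> List[int]:
    # Consume one transition key at a time; a missing (state, key) entry means
    # the token cannot be matched from `start_state`, so no states are returned.
    state = start_state
    accepted_states: List[int] = []
    for trans_key in token_transition_keys:
        new_state = fsm_transitions.get((state, trans_key))
        if new_state is None:
            return []
        state = new_state
        accepted_states.append(state)
    return accepted_states

# Toy FSM: key 0 moves state 0 -> 1, key 1 moves state 1 -> 2.
assert walk_fsm_sketch({(0, 0): 1, (1, 1): 2}, [0, 1], 0) == [1, 2]
assert walk_fsm_sketch({(0, 0): 1, (1, 1): 2}, [1], 0) == []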
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -10,6 +10,7 @@
 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
 from .mamba import Mamba, mamba
+from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
 from .transformers import Transformers, transformers
 from .vllm import VLLM, vllm
122 changes: 122 additions & 0 deletions outlines/models/mlxlm.py
@@ -0,0 +1,122 @@
from typing import TYPE_CHECKING, Optional, Tuple

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper


class MLXLM:
    """
    Represents an `mlx_lm` model

    Adapted from
    https://github.com/sacha-ichbiah/outlines-mlx/blob/main/outlinesmlx/models/mlx.py
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "TokenizerWrapper",
    ):
        self.model = model

        # mlx's TokenizerWrapper = HF tokenizer, `_tokenizer`, + a `_detokenizer`
        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)

    def forward(
        self,
        input_ids: "mx.array",
        attention_mask: "mx.array",
        past_key_values: "mx.array",
    ) -> Tuple["mx.array", Optional["mx.array"]]:
        """Compute a forward pass through the transformer model.

        Parameters
        ----------
        input_ids
            The input token ids. Must be one or two dimensional.
        attention_mask
            The attention mask. Must be one or two dimensional.
        past_key_values
            A tuple of tuples containing the cached key and value tensors for each
            attention head.

        Returns
        -------
        The computed logits and the new cached key and value tensors.
        """
        assert 0 < input_ids.ndim < 3

        if past_key_values:
            input_ids = input_ids[..., -1][..., None]

        logits, kv_cache = self.model(
            input_ids,
            cache=past_key_values,
        )

        return logits, kv_cache

    def __call__(
        self,
        input_ids: "mx.array",
        attention_mask: "mx.array",
        past_key_values: Optional["mx.array"] = None,
    ) -> Tuple["mx.array", "mx.array"]:
        logits, kv_cache = self.forward(input_ids, None, past_key_values)
        next_token_logits = logits[..., -1, :]

        return next_token_logits, kv_cache


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    # TODO: include these kwargs when mlx-lm has new release
    # model_config: dict = {},
    # adapter_path: Optional[str] = None,
    # lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name
        The path or the Hugging Face repository to load the model from.
    tokenizer_config
        Configuration parameters specifically for the tokenizer.
        Defaults to an empty dictionary.
    model_config
        Configuration parameters specifically for the model.
        Defaults to an empty dictionary.
    adapter_path
        Path to the LoRA adapters. If provided, applies LoRA layers
        to the model. Default: ``None``.
    lazy
        If ``False``, eval the model parameters to make sure they are
        loaded in memory before returning, otherwise they will be loaded
        when needed. Default: ``False``.

    Returns
    -------
    A `MLXLM` model instance.
    """
    try:
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        # model_config=model_config,
        # adapter_path=adapter_path,
        # lazy=lazy,
    )
    return MLXLM(model, tokenizer)
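For reference, a minimal usage sketch of the new entry point. The checkpoint name is the one used in the test fixture added below; running this requires Apple silicon with the `mlx-lm` package installed, and the prompt and token budget are illustrative.

import outlines.generate as generate
import outlines.models as models

# Load an mlx-community checkpoint through the new outlines.models.mlxlm entry point.
model = models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
generator = generate.text(model)

# Unconstrained text generation with the shared generator API.
print(generator("Write a one-line summary of finite-state machines.", max_tokens=30))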
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ test = [
"beartype<0.16.0",
"responses",
"llama-cpp-python",
"mlx-lm",
"huggingface_hub",
"openai>=1.0.0",
"vllm",
@@ -110,6 +111,8 @@ module = [
"jsonschema.*",
"openai.*",
"mamba_ssm.*",
"mlx_lm.*",
"mlx.*",
"nest_asyncio",
"numpy.*",
"cloudpickle.*",
3 changes: 0 additions & 3 deletions tests/fsm/test_parsing.py
@@ -204,6 +204,3 @@ def test_sequential_parse_example(cleanup_lark_import):

         if i + 1 == len(input_tokens):
             assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])
-
-    # Clean up lark.lark.LarkOptions._defaults
-    importlib.reload(lark.lark)
55 changes: 55 additions & 0 deletions tests/generate/test_generate.py
@@ -0,0 +1,55 @@
import re

import pytest

import outlines.generate as generate
import outlines.models as models


@pytest.fixture(scope="session")
def model_llamacpp(tmp_path_factory):
    return models.llamacpp(
        repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
        filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
    )


@pytest.fixture(scope="session")
def model_mlxlm(tmp_path_factory):
    return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")


@pytest.fixture(scope="session")
def model_transformers(tmp_path_factory):
    return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu")


@pytest.mark.parametrize(
    "model_fixture",
    ("model_llamacpp", "model_mlxlm", "model_transformers"),
)
def test_generate_text(request, model_fixture):
    model = request.getfixturevalue(model_fixture)
    generator = generate.text(model)
    res = generator("test", max_tokens=10)
    assert isinstance(res, str)


@pytest.mark.parametrize(
    "model_fixture",
    ("model_llamacpp", "model_mlxlm", "model_transformers"),
)
@pytest.mark.parametrize(
    "pattern",
    (
        "[0-9]",
        "abc*",
        "\\+?[1-9][0-9]{7,14}",
        r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
    ),
)
def test_generate_regex(request, model_fixture, pattern):
    model = request.getfixturevalue(model_fixture)
    generator = generate.regex(model, pattern)
    res = generator("foobarbaz", max_tokens=20)
    assert re.match(pattern, res) is not None, res
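The regex test above relies on `generate.regex` constraining generation to the given pattern. A standalone sketch of the same call, reusing the phone-number pattern from the parametrization (model, prompt, and token budget are illustrative):

import outlines.generate as generate
import outlines.models as models

# Assumes `mlx-lm` is installed and the checkpoint is available locally or on the Hub.
model = models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
phone_generator = generate.regex(model, "\\+?[1-9][0-9]{7,14}")

# The constrained generator only emits strings that match the pattern.
print(phone_generator("Call me at: ", max_tokens=20))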
33 changes: 0 additions & 33 deletions tests/generate/test_integration_llamacpp.py
@@ -20,39 +20,6 @@ def model(tmp_path_factory):
     )
 
 
-@pytest.mark.parametrize(
-    "generator_type,params",
-    (
-        (generate.text, []),
-        (generate.regex, ("[0-9]",)),
-        (generate.cfg, (grammars.arithmetic,)),
-    ),
-)
-def test_llamacpp_generation_api(model, generator_type, params):
-    generator = generator_type(model, *params)
-
-    res = generator("test", max_tokens=10)
-    assert isinstance(res, str)
-
-    res = generator("test", max_tokens=10)
-    assert isinstance(res, str)
-
-    res = generator("test", stop_at=".")
-    assert isinstance(res, str)
-
-    res = generator("test", stop_at=[".", "ab"])
-    assert isinstance(res, str)
-
-    res = generator("test", stop_at=[".", "ab"])
-    assert isinstance(res, str)
-
-    res1 = generator("test", seed=1, max_tokens=10)
-    res2 = generator("test", seed=1, max_tokens=10)
-    assert isinstance(res1, str)
-    assert isinstance(res2, str)
-    assert res1 == res2
-
-
 def test_llama_cpp_streaming_api(model):
     generator = generate.text(model)
     token_generator = generator.stream("test", max_tokens=10)
