Commit 48b6f8f

Fix mamba integration by making it a variant of outlines.models.transformers

lapp0 committed Jul 15, 2024
1 parent 73b84c8 commit 48b6f8f
Showing 6 changed files with 160 additions and 114 deletions.
7 changes: 0 additions & 7 deletions docs/reference/models/mamba.md

This file was deleted.

67 changes: 67 additions & 0 deletions docs/reference/models/transformers.md
@@ -82,3 +82,70 @@ print(output)
```

[transformers]: https://github.com/huggingface/transformers


# Alternative Model Classes

`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc.

However, other model classes with different behavior can be used as well by passing the appropriate class.
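
For example, passing `transformers.AutoModelForCausalLM` explicitly is equivalent to the default (a minimal sketch; the model name is illustrative):

```python
import outlines
from transformers import AutoModelForCausalLM

model = outlines.models.transformers(
    "mistralai/Mistral-7B-v0.1",
    model_class=AutoModelForCausalLM,
)
```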

### Mamba

[Mamba](https://github.com/state-spaces/mamba) is a transformer alternative that offers memory-efficient, linear-time decoding.

To use Mamba with Outlines, first install the necessary requirements:
```
pip install "causal-conv1d>=1.2.0" mamba-ssm torch transformers
```

Then you can either create a Mamba Outlines model via
```python
import outlines

model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
```

or explicitly with
```python
import outlines
from transformers import MambaForCausalLM

model = outlines.models.transformers(
"state-spaces/mamba-2.8b-hf",
model_class=MambaForCausalLM
)
```
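
The resulting model supports structured generation like any other Outlines `transformers` model. A minimal sketch (the prompt and regex pattern are illustrative):

```python
import outlines

model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
generator = outlines.generate.regex(model, r"[0-9]{4}")
answer = generator("In what year did the French Revolution begin? ")
```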

Further Reading:
- https://huggingface.co/docs/transformers/en/model_doc/mamba

### Encoder-Decoder Models

You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines.

Be cautious with model selection, though: some models, such as `t5-base`, are missing certain characters (`{`) from their vocabulary, and you may get an error when attempting structured generation.
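
One way to check in advance is to test whether the character survives a tokenize/detokenize round trip (a quick sketch, assuming a standard `transformers` tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
token_ids = tokenizer.encode("{", add_special_tokens=False)
# If "{" is missing from the vocabulary, decoding yields "<unk>" or ""
print(tokenizer.decode(token_ids))
```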

T5 Example:
```python
import outlines
from transformers import AutoModelForSeq2SeqLM

model_pile_t5 = outlines.models.transformers(
model_name="EleutherAI/pile-t5-large",
model_class=AutoModelForSeq2SeqLM,
)
```

Bart Example:
```python
import outlines
from transformers import AutoModelForSeq2SeqLM

model_bart = outlines.models.transformers(
model_name="facebook/bart-large",
model_class=AutoModelForSeq2SeqLM,
)
```
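
Once loaded, these models plug into the usual generation API (a minimal sketch; the prompt is illustrative):

```python
import outlines

generator = outlines.generate.text(model_bart)
result = generator("The quick brown fox.", max_tokens=10)
print(result)
```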


### Multi-Modal Models

*Coming soon*
5 changes: 2 additions & 3 deletions outlines/models/__init__.py
@@ -9,10 +9,9 @@

 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
-from .mamba import Mamba, mamba
 from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
-from .transformers import Transformers, TransformerTokenizer, transformers
+from .transformers import Transformers, TransformerTokenizer, mamba, transformers
 from .vllm import VLLM, vllm

-LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM]
+LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
61 changes: 0 additions & 61 deletions outlines/models/mamba.py

This file was deleted.

57 changes: 47 additions & 10 deletions outlines/models/transformers.py
@@ -1,4 +1,5 @@
 import dataclasses
+import inspect
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union

 from datasets.fingerprint import Hasher
@@ -226,10 +227,16 @@ def generate(
             input_ids, attention_mask = self.tokenizer.encode([prompts])
         else:
             input_ids, attention_mask = self.tokenizer.encode(prompts)
+
         inputs = {
             "input_ids": input_ids.to(self.model.device),
             "attention_mask": attention_mask.to(self.model.device),
         }
+        if (
+            "attention_mask"
+            not in inspect.signature(self.model.forward).parameters.keys()
+        ):
+            del inputs["attention_mask"]

         generation_kwargs = self._get_generation_kwargs(
             prompts,
@@ -265,7 +272,7 @@ def stream(
         input_ids, attention_mask = self.tokenizer.encode(prompts)
         inputs = {
             "input_ids": input_ids.to(self.model.device),
-            "attention_mask": attention_mask.to(self.model.device),
+            # "attention_mask": attention_mask.to(self.model.device),
         }

         generation_kwargs = self._get_generation_kwargs(
@@ -336,7 +343,7 @@ def _generate_output_seq(
     ):
         input_ids = inputs["input_ids"]
         output_ids = self.model.generate(
-            generation_config=generation_config, **inputs, **generation_kwargs
+            **inputs, generation_config=generation_config, **generation_kwargs
         )

         # encoder-decoder returns output_ids only, decoder-only returns full seq ids
@@ -376,6 +383,8 @@ def transformers(
     device: Optional[str] = None,
     model_kwargs: dict = {},
     tokenizer_kwargs: dict = {},
+    model_class=None,
+    tokenizer_class=None,
 ):
     """Instantiate a model from the `transformers` library and its tokenizer.
@@ -398,19 +407,47 @@
     A `TransformersModel` model instance.
     """
-    try:
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-        )
+    if model_class is None or tokenizer_class is None:
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library needs to be installed in order to use `transformers` models."
+            )
+        if model_class is None:
+            model_class = AutoModelForCausalLM
+        if tokenizer_class is None:
+            tokenizer_class = AutoTokenizer

     if device is not None:
         model_kwargs["device_map"] = device

-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+    model = model_class.from_pretrained(model_name, **model_kwargs)

     tokenizer_kwargs.setdefault("padding_side", "left")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
+    tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)

     return Transformers(model, tokenizer)


+def mamba(
+    model_name: str,
+    device: Optional[str] = None,
+    model_kwargs: dict = {},
+    tokenizer_kwargs: dict = {},
+):
+    try:
+        from transformers import MambaForCausalLM
+
+    except ImportError:
+        raise ImportError(
+            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba models."
+        )
+
+    return transformers(
+        model_name=model_name,
+        device=device,
+        model_kwargs=model_kwargs,
+        tokenizer_kwargs=tokenizer_kwargs,
+        model_class=MambaForCausalLM,
+    )
77 changes: 44 additions & 33 deletions tests/generate/test_generate.py
@@ -38,7 +38,7 @@ def model_transformers_opt125m(tmp_path_factory):

 @pytest.fixture(scope="session")
 def model_vllm(tmp_path_factory):
-    return models.vllm("facebook/opt-125m")
+    return models.vllm("facebook/opt-125m", device="cpu")


 @pytest.fixture(scope="session")
@@ -48,10 +48,19 @@ def model_mamba(tmp_path_factory):

 @pytest.fixture(scope="session")
 def model_t5(tmp_path_factory):
-    from transformers import T5ForConditionalGeneration
+    from transformers import AutoModelForSeq2SeqLM

     return models.transformers(
-        "google/t5-efficient-mini", device="cpu", model_class=T5ForConditionalGeneration
+        "EleutherAI/pile-t5-base", device="cpu", model_class=AutoModelForSeq2SeqLM
     )


+@pytest.fixture(scope="session")
+def model_bart(tmp_path_factory):
+    from transformers import AutoModelForSeq2SeqLM
+
+    return models.transformers(
+        "facebook/bart-base", device="cpu", model_class=AutoModelForSeq2SeqLM
+    )


@@ -75,6 +84,7 @@ def model_exllamav2(tmp_path_factory):
"model_vllm",
"model_mamba",
"model_t5",
"model_bart",
)


@@ -115,14 +125,13 @@ def test_generate_text(request, model_fixture, sampler_name):
     assert isinstance(res, str)


+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_batch_text(request, model_fixture):
+def test_generate_regex(request, model_fixture, pattern):
     model = request.getfixturevalue(model_fixture)
-    generator = generate.text(model)
-    with enforce_not_implemented(model_fixture, "batch"):
-        res = generator(["test", "test2"], max_tokens=10)
-    assert isinstance(res, list)
-    assert isinstance(res[0], str)
+    generator = generate.regex(model, pattern)
+    res = generator("foobarbaz", max_tokens=20)
+    assert re.fullmatch(pattern, res) is not None, res


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
Expand Down Expand Up @@ -166,18 +175,14 @@ def test_generate_regex_stream(request, model_fixture, pattern):
         assert re.fullmatch(pattern, output) is not None, output


-@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_regex_batch_stream(request, model_fixture, pattern):
+def test_generate_batch_text(request, model_fixture):
     model = request.getfixturevalue(model_fixture)
-    generator = generate.regex(model, pattern)
-    with enforce_not_implemented(model_fixture, "batch", "stream"):
-        outputs = ["", ""]
-        for tokens in generator.stream(["input 0", "input 1"], max_tokens=20):
-            outputs[0] += tokens[0]
-            outputs[1] += tokens[1]
-    for output in outputs:
-        assert re.fullmatch(pattern, output) is not None, output
+    generator = generate.text(model)
+    with enforce_not_implemented(model_fixture, "batch"):
+        res = generator(["test", "test2"], max_tokens=10)
+    assert isinstance(res, list)
+    assert isinstance(res[0], str)


 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@@ -194,35 +199,41 @@ def test_generate_regex_batch(request, model_fixture, pattern):

@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_regex_single_multinomial(request, model_fixture, pattern):
"""Ensure batch requests work and fsm order is maintained"""
def test_generate_regex_batch_stream(request, model_fixture, pattern):
model = request.getfixturevalue(model_fixture)
generator = generate.regex(model, pattern, sampler=samplers.multinomial(4))
with enforce_not_implemented(model_fixture, "multiple_samples"):
output_sample_groups = generator("single input", max_tokens=40)
for output in output_sample_groups:
generator = generate.regex(model, pattern)
with enforce_not_implemented(model_fixture, "batch", "stream"):
outputs = ["", ""]
for tokens in generator.stream(["input 0", "input 1"], max_tokens=20):
outputs[0] += tokens[0]
outputs[1] += tokens[1]
for output in outputs:
assert re.fullmatch(pattern, output) is not None, output


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_regex_batch_multinomial(request, model_fixture, pattern):
def test_generate_regex_single_multinomial(request, model_fixture, pattern):
"""Ensure batch requests work and fsm order is maintained"""
model = request.getfixturevalue(model_fixture)
generator = generate.regex(model, pattern, sampler=samplers.multinomial(4))
with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40)
for output_sample_groups in output_batch_groups:
for output in output_sample_groups:
assert re.fullmatch(pattern, output) is not None, output
with enforce_not_implemented(model_fixture, "multiple_samples"):
output_sample_groups = generator("single input", max_tokens=40)
for output in output_sample_groups:
assert re.fullmatch(pattern, output) is not None, output


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_regex_batch_beam_search(request, model_fixture, pattern):
@pytest.mark.parametrize("sampler_name", ("multinomial", "beam_search"))
def test_generate_regex_batch_multi_sample(
request, model_fixture, pattern, sampler_name
):
"""Ensure batch requests work and fsm order is maintained"""
model = request.getfixturevalue(model_fixture)
generator = generate.regex(model, pattern, sampler=samplers.beam_search(4))
generator = generate.regex(
model, pattern, sampler=getattr(samplers, sampler_name)(4)
)
with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40)
for output_sample_groups in output_batch_groups:
