Skip to content

Commit

Permalink
Fix mamba integration by making it a variant of outlines.models.trans…
Browse files Browse the repository at this point in the history
…formers
  • Loading branch information
lapp0 committed Jun 14, 2024
1 parent 1537695 commit 84ea1eb
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 90 deletions.
7 changes: 0 additions & 7 deletions docs/reference/models/mamba.md

This file was deleted.

59 changes: 59 additions & 0 deletions docs/reference/models/transformers.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,62 @@ print(output)
```

[transformers]: https://github.com/huggingface/transformers


# Alternative Model Classes

`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc.

However other variants with unique behavior can be used as well by passing the appropriate class.

### Mamba

[Mamba](https://github.com/state-spaces/mamba) is a transformers alternative which employs memory efficient, linear-time decoding.

To use Mamba with outlines you must first install the necessary requirements:
```
pip install causal-conv1d>=1.2.0 mamba-ssm torch transformers
```

Then you can either create an Mamba-2 Outlines model via
```
import outlines
model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
```

or explicitly with
```
import outlines
from transformers import MambaForCausalLM
model = outlines.models.transformers(
"state-spaces/mamba-2.8b-hf",
model_class=MambaForCausalLM
)
```

Further Reading:
- https://huggingface.co/docs/transformers/en/model_doc/mamba

### Encoder-Decoder Models

You can use encoder-decoder (seq2seq) models like T5 and BERT with Outlines.

Be cautious with model selection though, some models such as `t5-base` don't include certain characters (`{`) and you may get an error when trying to perform structured generation.

Example:
```
import outlines
from transformers import AutoModelForSeq2SeqLM
model = models.transformers(
model_name="EleutherAI/pile-t5-large",
model_class=transformers.AutoModelForSeq2SeqLM,
)
```


### Multi-Modal Models

/Coming soon/
5 changes: 2 additions & 3 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@

from .exllamav2 import ExLlamaV2Model, exl2
from .llamacpp import LlamaCpp, llamacpp
from .mamba import Mamba, mamba
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, TransformerTokenizer, transformers
from .transformers import Transformers, TransformerTokenizer, transformers, mamba
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba]
LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model]
61 changes: 0 additions & 61 deletions outlines/models/mamba.py

This file was deleted.

46 changes: 38 additions & 8 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def transformers(
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
model_class=None,
tokenizer_class=None,
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Expand All @@ -391,19 +393,47 @@ def transformers(
A `TransformersModel` model instance.
"""
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if model_class is None or tokenizer_class is None:
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if model_class is None:
model_class = AutoModelForCausalLM
if tokenizer_class is None:
tokenizer_class = AutoTokenizer

if device is not None:
model_kwargs["device_map"] = device

model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
model = model_class.from_pretrained(model_name, **model_kwargs)

tokenizer_kwargs.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)

return Transformers(model, tokenizer)


def mamba(
model_name: str,
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
):
try:
from transformers import MambaForCausalLM

except ImportError:
raise ImportError(
"The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba people."
)

return transformers(
model_name=model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
model_class=MambaForCausalLM,
)
28 changes: 17 additions & 11 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,28 @@ def model_vllm(tmp_path_factory):
return models.vllm("facebook/opt-125m")


# TODO: mamba / exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
"""
@pytest.fixture(scope="session")
def model_exllamav2(tmp_path_factory):
return models.exllamav2(
model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
device="cpu"
def model_mamba(tmp_path_factory):
return models.mamba(model_name="state-spaces/mamba-130m-hf", device="cpu")


@pytest.fixture(scope="session")
def model_t5(tmp_path_factory):
from transformers import T5ForConditionalGeneration

return models.transformers(
"google/t5-efficient-mini", device="cpu", model_class=T5ForConditionalGeneration
)


# TODO: exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
"""
@pytest.fixture(scope="session")
def model_mamba(tmp_path_factory):
return models.mamba(
model_name="state-spaces/mamba-130m-hf",
def model_exllamav2(tmp_path_factory):
return models.exllamav2(
model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
device="cpu"
)
ALL_MODEL_FIXTURES = ("model_llamacpp", "model_mlxlm", "model_transformers", "model_vllm", "model_exllamav2", "model_mamba")
"""


Expand All @@ -57,6 +61,8 @@ def model_mamba(tmp_path_factory):
"model_mlxlm",
"model_transformers",
"model_vllm",
"model_mamba",
"model_t5",
)


Expand Down

0 comments on commit 84ea1eb

Please sign in to comment.