Fix mamba integration by making it a variant of outlines.models.transformers
lapp0 committed Jul 15, 2024
1 parent 5a7f082 commit 75dc370
Showing 6 changed files with 158 additions and 81 deletions.
7 changes: 0 additions & 7 deletions docs/reference/models/mamba.md

This file was deleted.

67 changes: 67 additions & 0 deletions docs/reference/models/transformers.md
@@ -82,3 +82,70 @@ print(output)
```

[transformers]: https://github.com/huggingface/transformers


# Alternative Model Classes

`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc.

However, other model classes with different behavior can be used as well by passing the appropriate class.
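
For reference, the default behavior can also be spelled out explicitly with the `model_class` argument. A minimal sketch (the model name is illustrative; `AutoModelForCausalLM` is the class the function falls back to anyway):

```python
import outlines
from transformers import AutoModelForCausalLM

# Equivalent to outlines.models.transformers("mistralai/Mistral-7B-v0.1"),
# since AutoModelForCausalLM is the default model class.
model = outlines.models.transformers(
    "mistralai/Mistral-7B-v0.1",
    model_class=AutoModelForCausalLM,
)
```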

### Mamba

[Mamba](https://github.com/state-spaces/mamba) is an alternative to the transformer architecture which uses memory-efficient, linear-time decoding.

To use Mamba with Outlines, you must first install the necessary requirements:
```shell
pip install "causal-conv1d>=1.2.0" mamba-ssm torch transformers
```

Then you can create a Mamba Outlines model either via
```python
import outlines

model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
```

or explicitly with
```python
import outlines
from transformers import MambaForCausalLM

model = outlines.models.transformers(
    "state-spaces/mamba-2.8b-hf",
    model_class=MambaForCausalLM
)
```
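
Either way, the resulting model plugs into the standard Outlines generation API. A minimal sketch continuing from the `model` defined above (the prompt and `max_tokens` value are illustrative):

```python
# Unconstrained text generation with the Mamba model loaded above.
generator = outlines.generate.text(model)
print(generator("Describe the Mamba architecture in one sentence.", max_tokens=30))
```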

Further Reading:
- https://huggingface.co/docs/transformers/en/model_doc/mamba

### Encoder-Decoder Models

You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines.

Be cautious with model selection, though: some models, such as `t5-base`, don't include certain characters (`{`) in their vocabulary, and you may get an error when trying to perform structured generation.

T5 Example:
```python
import outlines
from transformers import AutoModelForSeq2SeqLM

model_pile_t5 = outlines.models.transformers(
    model_name="EleutherAI/pile-t5-large",
    model_class=AutoModelForSeq2SeqLM,
)
```

Bart Example:
```python
import outlines
from transformers import AutoModelForSeq2SeqLM

model_bart = outlines.models.transformers(
    model_name="facebook/bart-large",
    model_class=AutoModelForSeq2SeqLM,
)
```
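
Once loaded, seq2seq models use the same generation API as decoder-only models. A minimal sketch continuing from the `model_bart` defined above (the prompt and labels are illustrative):

```python
# Constrain the output to one of two labels.
generator = outlines.generate.choice(model_bart, ["Positive", "Negative"])
print(generator("Review: A wonderful, moving film. Sentiment:"))
```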


### Multi-Modal Models

*Coming soon*
5 changes: 2 additions & 3 deletions outlines/models/__init__.py
@@ -9,10 +9,9 @@

from .exllamav2 import ExLlamaV2Model, exl2
from .llamacpp import LlamaCpp, llamacpp
from .mamba import Mamba, mamba
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, TransformerTokenizer, transformers
from .transformers import Transformers, TransformerTokenizer, mamba, transformers
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM]
LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
61 changes: 0 additions & 61 deletions outlines/models/mamba.py

This file was deleted.

60 changes: 51 additions & 9 deletions outlines/models/transformers.py
@@ -1,4 +1,5 @@
import dataclasses
import inspect
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union

from datasets.fingerprint import Hasher
@@ -226,10 +227,16 @@ def generate(
            input_ids, attention_mask = self.tokenizer.encode([prompts])
        else:
            input_ids, attention_mask = self.tokenizer.encode(prompts)

        inputs = {
            "input_ids": input_ids.to(self.model.device),
            "attention_mask": attention_mask.to(self.model.device),
        }
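        # Drop the attention mask for models whose forward() signature does not
        # accept one (e.g. Mamba); such models are called with input_ids only.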
        if (
            "attention_mask"
            not in inspect.signature(self.model.forward).parameters.keys()
        ):
            del inputs["attention_mask"]

        generation_kwargs = self._get_generation_kwargs(
            prompts,
@@ -267,6 +274,11 @@ def stream(
"input_ids": input_ids.to(self.model.device),
"attention_mask": attention_mask.to(self.model.device),
}
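        # As in generate(): drop the attention mask when the model's forward()
        # does not accept one (e.g. Mamba).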
        if (
            "attention_mask"
            not in inspect.signature(self.model.forward).parameters.keys()
        ):
            del inputs["attention_mask"]

        generation_kwargs = self._get_generation_kwargs(
            prompts,
@@ -336,7 +348,7 @@ def _generate_output_seq(
    ):
        input_ids = inputs["input_ids"]
        output_ids = self.model.generate(
            generation_config=generation_config, **inputs, **generation_kwargs
            **inputs, generation_config=generation_config, **generation_kwargs
        )

        # encoder-decoder returns output_ids only, decoder-only returns full seq ids
@@ -376,6 +388,8 @@ def transformers(
    device: Optional[str] = None,
    model_kwargs: dict = {},
    tokenizer_kwargs: dict = {},
    model_class=None,
    tokenizer_class=None,
):
    """Instantiate a model from the `transformers` library and its tokenizer.
@@ -398,19 +412,47 @@
    A `TransformersModel` model instance.
    """
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        raise ImportError(
            "The `transformers` library needs to be installed in order to use `transformers` models."
        )
    if model_class is None or tokenizer_class is None:
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            raise ImportError(
                "The `transformers` library needs to be installed in order to use `transformers` models."
            )
    if model_class is None:
        model_class = AutoModelForCausalLM
    if tokenizer_class is None:
        tokenizer_class = AutoTokenizer

    if device is not None:
        model_kwargs["device_map"] = device

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    model = model_class.from_pretrained(model_name, **model_kwargs)

    tokenizer_kwargs.setdefault("padding_side", "left")
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)

    return Transformers(model, tokenizer)


def mamba(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: dict = {},
    tokenizer_kwargs: dict = {},
):
    try:
        from transformers import MambaForCausalLM

    except ImportError:
        raise ImportError(
            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba."
        )

    return transformers(
        model_name=model_name,
        device=device,
        model_kwargs=model_kwargs,
        tokenizer_kwargs=tokenizer_kwargs,
        model_class=MambaForCausalLM,
    )
39 changes: 38 additions & 1 deletion tests/generate/test_generate.py
@@ -36,17 +36,54 @@ def model_transformers_opt125m(tmp_path_factory):
    return models.transformers("facebook/opt-125m", device="cpu")


@pytest.fixture(scope="session")
def model_mamba(tmp_path_factory):
    return models.mamba(model_name="state-spaces/mamba-130m-hf", device="cpu")


@pytest.fixture(scope="session")
def model_bart(tmp_path_factory):
    from transformers import AutoModelForSeq2SeqLM

    return models.transformers(
        "facebook/bart-base", device="cpu", model_class=AutoModelForSeq2SeqLM
    )


# TODO: exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
# TODO: t5 tokenizer doesn't work with streaming
"""
@pytest.fixture(scope="session")
def model_exllamav2(tmp_path_factory):
    return models.exllamav2(
        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
        device="cpu"
    )
@pytest.fixture(scope="session")
def model_t5(tmp_path_factory):
    from transformers import AutoModelForSeq2SeqLM
    return models.transformers(
        "EleutherAI/pile-t5-base", device="cpu", model_class=AutoModelForSeq2SeqLM
    )
"""


ALL_MODEL_FIXTURES = (
"model_llamacpp",
"model_mlxlm",
"model_mlxlm_phi3",
"model_transformers_random",
"model_transformers_opt125m",
"model_mamba",
"model_t5",
"model_bart",
)


NOT_IMPLEMENTED = {
"stream": ["model_vllm"],
"stream": [],
"batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
"beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
"multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
