diff --git a/docs/reference/models/mamba.md b/docs/reference/models/mamba.md
deleted file mode 100644
index ac6db3682..000000000
--- a/docs/reference/models/mamba.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Mamba
-
-```bash
-pip install mamba_ssm transformers torch
-```
-
-*Coming soon*
diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
index 7c1febd02..80b70845f 100644
--- a/docs/reference/models/transformers.md
+++ b/docs/reference/models/transformers.md
@@ -82,3 +82,70 @@ print(output)
 ```
 
 [transformers]: https://github.com/huggingface/transformers
+
+
+# Alternative Model Classes
+
+`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc.
+
+However, other variants with unique behavior can be used as well by passing the appropriate class via the `model_class` argument.
+
+### Mamba
+
+[Mamba](https://github.com/state-spaces/mamba) is an alternative to the transformer architecture that employs memory-efficient, linear-time decoding.
+
+To use Mamba with Outlines, you must first install the necessary requirements:
+```bash
+pip install causal-conv1d>=1.2.0 mamba-ssm torch transformers
+```
+
+Then you can either create a Mamba Outlines model via
+```python
+import outlines
+
+model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
+```
+
+or explicitly with
+```python
+import outlines
+from transformers import MambaForCausalLM
+
+model = outlines.models.transformers(
+    "state-spaces/mamba-2.8b-hf",
+    model_class=MambaForCausalLM
+)
+```
+
+Further Reading:
+- https://huggingface.co/docs/transformers/en/model_doc/mamba
+
+### Encoder-Decoder Models
+
+You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines.
+
+Be cautious with model selection, though: some models, such as `t5-base`, are missing certain characters (`{`, for example) from their vocabulary, and you may get an error when trying to perform structured generation.
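+
+A quick way to verify this (a minimal sketch, assuming the Hugging Face `AutoTokenizer` API and the `t5-base` checkpoint mentioned above) is to check whether the character survives a tokenizer round trip:
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("t5-base")
+
+# Encode the character and decode it back; if "{" is missing from the
+# vocabulary it is mapped to the unknown token and does not round-trip.
+token_ids = tokenizer.encode("{", add_special_tokens=False)
+print(tokenizer.decode(token_ids))
+```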
+
+T5 Example:
+```python
+from outlines import models
+from transformers import AutoModelForSeq2SeqLM
+
+model_pile_t5 = models.transformers(
+    model_name="EleutherAI/pile-t5-large",
+    model_class=AutoModelForSeq2SeqLM,
+)
+```
+
+BART Example:
+```python
+model_bart = models.transformers(
+    model_name="facebook/bart-large",
+    model_class=AutoModelForSeq2SeqLM,
+)
+```
+
+
+### Multi-Modal Models
+
+*Coming soon*
diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index fde913e2c..c161215d1 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -9,10 +9,9 @@
 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
-from .mamba import Mamba, mamba
 from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
-from .transformers import Transformers, TransformerTokenizer, transformers
+from .transformers import Transformers, TransformerTokenizer, mamba, transformers
 from .vllm import VLLM, vllm
 
-LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM]
+LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
diff --git a/outlines/models/mamba.py b/outlines/models/mamba.py
deleted file mode 100644
index d3dabf669..000000000
--- a/outlines/models/mamba.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import TYPE_CHECKING, Optional
-
-from .transformers import TransformerTokenizer
-
-if TYPE_CHECKING:
-    import torch
-    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-    from transformers import PreTrainedTokenizer
-
-
-TOKENIZER_MODEL = "EleutherAI/gpt-neox-20b"
-
-
-class Mamba:
-    """Represent a `mamba` model."""
-
-    def __init__(
-        self, model: "MambaLMHeadModel", tokenizer: "PreTrainedTokenizer", device
-    ):
-        self.device = device
-        self.model = model
-        self.tokenizer = TransformerTokenizer(tokenizer)
-
-    def forward(self, input_ids: "torch.LongTensor", *_):
-        """Compute a forward pass through the mamba model."""
-
-        output = self.model(input_ids)
-        next_token_logits = output.logits[..., -1, :]
-        return next_token_logits, None
-
-    def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
-        return self.forward(input_ids)
-
-
-def mamba(
-    model_name: str,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    tokenizer_kwargs: dict = {},
-):
-    try:
-        import torch
-        from mamba_ssm import MambaLMHeadModel
-        from transformers import AutoTokenizer
-    except ImportError:
-        raise ImportError(
-            "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba."
-        )
-
-    if not torch.cuda.is_available():
-        raise NotImplementedError("Mamba models can only run on GPU.")
-    else:
-        if device is None:
-            device = "cuda"
-
-    model = MambaLMHeadModel.from_pretrained(model_name, device=device)
-
-    tokenizer_kwargs.setdefault("padding_side", "left")
-    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL, **tokenizer_kwargs)
-
-    return Mamba(model, tokenizer, device)
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 4c475617a..1d3be511a 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -1,4 +1,5 @@
 import dataclasses
+import inspect
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
 from datasets.fingerprint import Hasher
@@ -226,10 +227,16 @@ def generate(
             input_ids, attention_mask = self.tokenizer.encode([prompts])
         else:
             input_ids, attention_mask = self.tokenizer.encode(prompts)
+
         inputs = {
             "input_ids": input_ids.to(self.model.device),
             "attention_mask": attention_mask.to(self.model.device),
         }
+        if (
+            "attention_mask"
+            not in inspect.signature(self.model.forward).parameters.keys()
+        ):
+            del inputs["attention_mask"]
 
         generation_kwargs = self._get_generation_kwargs(
             prompts,
@@ -267,6 +274,11 @@ def stream(
             "input_ids": input_ids.to(self.model.device),
             "attention_mask": attention_mask.to(self.model.device),
         }
+        if (
+            "attention_mask"
+            not in inspect.signature(self.model.forward).parameters.keys()
+        ):
+            del inputs["attention_mask"]
 
         generation_kwargs = self._get_generation_kwargs(
             prompts,
@@ -336,7 +348,7 @@ def _generate_output_seq(
     ):
         input_ids = inputs["input_ids"]
         output_ids = self.model.generate(
-            generation_config=generation_config, **inputs, **generation_kwargs
+            **inputs, generation_config=generation_config, **generation_kwargs
         )
 
         # encoder-decoder returns output_ids only, decoder-only returns full seq ids
@@ -376,6 +388,8 @@ def transformers(
     device: Optional[str] = None,
     model_kwargs: dict = {},
     tokenizer_kwargs: dict = {},
+    model_class=None,
+    tokenizer_class=None,
 ):
     """Instantiate a model from the `transformers` library and its tokenizer.
 
@@ -398,19 +412,47 @@ def transformers(
     A `TransformersModel` model instance.
 
     """
-    try:
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-        )
+    if model_class is None or tokenizer_class is None:
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library needs to be installed in order to use `transformers` models."
+            )
+        if model_class is None:
+            model_class = AutoModelForCausalLM
+        if tokenizer_class is None:
+            tokenizer_class = AutoTokenizer
 
     if device is not None:
         model_kwargs["device_map"] = device
 
-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+    model = model_class.from_pretrained(model_name, **model_kwargs)
 
     tokenizer_kwargs.setdefault("padding_side", "left")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
+    tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)
 
     return Transformers(model, tokenizer)
+
+
+def mamba(
+    model_name: str,
+    device: Optional[str] = None,
+    model_kwargs: dict = {},
+    tokenizer_kwargs: dict = {},
+):
+    try:
+        from transformers import MambaForCausalLM
+
+    except ImportError:
+        raise ImportError(
+            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba."
+        )
+
+    return transformers(
+        model_name=model_name,
+        device=device,
+        model_kwargs=model_kwargs,
+        tokenizer_kwargs=tokenizer_kwargs,
+        model_class=MambaForCausalLM,
+    )
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 9432c0e4e..d433c48fe 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -36,17 +36,54 @@ def model_transformers_opt125m(tmp_path_factory):
     return models.transformers("facebook/opt-125m", device="cpu")
 
 
+@pytest.fixture(scope="session")
+def model_mamba(tmp_path_factory):
+    return models.mamba(model_name="state-spaces/mamba-130m-hf", device="cpu")
+
+
+@pytest.fixture(scope="session")
+def model_bart(tmp_path_factory):
+    from transformers import AutoModelForSeq2SeqLM
+
+    return models.transformers(
+        "facebook/bart-base", device="cpu", model_class=AutoModelForSeq2SeqLM
+    )
+
+
+# TODO: exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
+# TODO: t5 tokenizer doesn't work with streaming
+"""
+@pytest.fixture(scope="session")
+def model_exllamav2(tmp_path_factory):
+    return models.exllamav2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        device="cpu"
+    )
+
+@pytest.fixture(scope="session")
+def model_t5(tmp_path_factory):
+    from transformers import AutoModelForSeq2SeqLM
+
+    return models.transformers(
+        "EleutherAI/pile-t5-base", device="cpu", model_class=AutoModelForSeq2SeqLM
+    )
+"""
+
+
 ALL_MODEL_FIXTURES = (
     "model_llamacpp",
     "model_mlxlm",
     "model_mlxlm_phi3",
     "model_transformers_random",
     "model_transformers_opt125m",
+    "model_mamba",
+    "model_t5",
+    "model_bart",
 )
 
 
 NOT_IMPLEMENTED = {
-    "stream": ["model_vllm"],
+    "stream": [],
     "batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
     "beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
     "multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],