diff --git a/docs/reference/models/mamba.md b/docs/reference/models/mamba.md
deleted file mode 100644
index ac6db3682..000000000
--- a/docs/reference/models/mamba.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Mamba
-
-```bash
-pip install mamba_ssm transformers torch
-```
-
-*Coming soon*
diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
index 7c1febd02..5e8b5865e 100644
--- a/docs/reference/models/transformers.md
+++ b/docs/reference/models/transformers.md
@@ -82,3 +82,62 @@ print(output)
 ```
 
 [transformers]: https://github.com/huggingface/transformers
+
+
+# Alternative Model Classes
+
+`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc.
+
+However, other model variants with unique behavior can be used as well by passing the appropriate class.
+
+### Mamba
+
+[Mamba](https://github.com/state-spaces/mamba) is an alternative to the transformer architecture that employs memory-efficient, linear-time decoding.
+
+To use Mamba with Outlines you must first install the necessary requirements:
+```
+pip install "causal-conv1d>=1.2.0" mamba-ssm torch transformers
+```
+
+Then you can either create a Mamba Outlines model via
+```
+import outlines
+
+model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
+```
+
+or explicitly with
+```
+import outlines
+from transformers import MambaForCausalLM
+
+model = outlines.models.transformers(
+    "state-spaces/mamba-2.8b-hf",
+    model_class=MambaForCausalLM
+)
+```
+
+Further Reading:
+- https://huggingface.co/docs/transformers/en/model_doc/mamba
+
+### Encoder-Decoder Models
+
+You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines.
+
+Be cautious with model selection, though: some models, such as `t5-base`, don't include certain characters (`{`) in their vocabulary, and you may get an error when trying to perform structured generation.
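+
+One way to check up front whether a tokenizer covers the characters your schema needs is to round-trip them through it. The snippet below is a minimal sketch; the `t5-base` checkpoint and the checked characters are only examples:
+```
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("t5-base")
+
+# Characters that structured generation typically needs; adjust for your schema.
+for char in ["{", "}", '"']:
+    token_ids = tokenizer.encode(char, add_special_tokens=False)
+    decoded = tokenizer.decode(token_ids)
+    # If the character does not survive the round trip, it is missing from the vocabulary.
+    print(char, "in vocabulary:", char in decoded)
+```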
+
+Example:
+```
+import outlines
+from transformers import AutoModelForSeq2SeqLM
+
+model = outlines.models.transformers(
+    model_name="EleutherAI/pile-t5-large",
+    model_class=AutoModelForSeq2SeqLM,
+)
+```
+
+
+### Multi-Modal Models
+
+*Coming soon*
diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index 65491ad18..660d6351c 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -9,10 +9,9 @@ from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
-from .mamba import Mamba, mamba
 from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
-from .transformers import Transformers, TransformerTokenizer, transformers
+from .transformers import Transformers, TransformerTokenizer, transformers, mamba
 from .vllm import VLLM, vllm
 
-LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba]
+LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model]
diff --git a/outlines/models/mamba.py b/outlines/models/mamba.py
deleted file mode 100644
index d3dabf669..000000000
--- a/outlines/models/mamba.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import TYPE_CHECKING, Optional
-
-from .transformers import TransformerTokenizer
-
-if TYPE_CHECKING:
-    import torch
-    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-    from transformers import PreTrainedTokenizer
-
-
-TOKENIZER_MODEL = "EleutherAI/gpt-neox-20b"
-
-
-class Mamba:
-    """Represent a `mamba` model."""
-
-    def __init__(
-        self, model: "MambaLMHeadModel", tokenizer: "PreTrainedTokenizer", device
-    ):
-        self.device = device
-        self.model = model
-        self.tokenizer = TransformerTokenizer(tokenizer)
-
-    def forward(self, input_ids: "torch.LongTensor", *_):
-        """Compute a forward pass through the mamba model."""
-
-        output = self.model(input_ids)
-        next_token_logits = output.logits[..., -1, :]
-        return next_token_logits, None
-
-    def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
-        return self.forward(input_ids)
-
-
-def mamba(
-    model_name: str,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    tokenizer_kwargs: dict = {},
-):
-    try:
-        import torch
-        from mamba_ssm import MambaLMHeadModel
-        from transformers import AutoTokenizer
-    except ImportError:
-        raise ImportError(
-            "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba people."
-        )
-
-    if not torch.cuda.is_available():
-        raise NotImplementedError("Mamba models can only run on GPU.")
-    else:
-        if device is None:
-            device = "cuda"
-
-    model = MambaLMHeadModel.from_pretrained(model_name, device=device)
-
-    tokenizer_kwargs.setdefault("padding_side", "left")
-    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL, **tokenizer_kwargs)
-
-    return Mamba(model, tokenizer, device)
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 7d32a43bd..cfd66014a 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -369,6 +369,8 @@ def transformers(
     device: Optional[str] = None,
     model_kwargs: dict = {},
     tokenizer_kwargs: dict = {},
+    model_class=None,
+    tokenizer_class=None,
 ):
     """Instantiate a model from the `transformers` library and its tokenizer.
 
@@ -391,19 +393,47 @@
     A `TransformersModel` model instance.
 
     """
-    try:
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-    )
+    if model_class is None or tokenizer_class is None:
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library needs to be installed in order to use `transformers` models."
+            )
+    if model_class is None:
+        model_class = AutoModelForCausalLM
+    if tokenizer_class is None:
+        tokenizer_class = AutoTokenizer
 
     if device is not None:
         model_kwargs["device_map"] = device
 
-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+    model = model_class.from_pretrained(model_name, **model_kwargs)
 
     tokenizer_kwargs.setdefault("padding_side", "left")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
+    tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)
 
     return Transformers(model, tokenizer)
+
+
+def mamba(
+    model_name: str,
+    device: Optional[str] = None,
+    model_kwargs: dict = {},
+    tokenizer_kwargs: dict = {},
+):
+    try:
+        from transformers import MambaForCausalLM
+
+    except ImportError:
+        raise ImportError(
+            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba models."
+        )
+
+    return transformers(
+        model_name=model_name,
+        device=device,
+        model_kwargs=model_kwargs,
+        tokenizer_kwargs=tokenizer_kwargs,
+        model_class=MambaForCausalLM,
+    )
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 5ceb82589..10c6e15a5 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -31,24 +31,28 @@ def model_vllm(tmp_path_factory):
     return models.vllm("facebook/opt-125m")
 
 
-# TODO: mamba / exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
-"""
 @pytest.fixture(scope="session")
-def model_exllamav2(tmp_path_factory):
-    return models.exllamav2(
-        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
-        device="cpu"
+def model_mamba(tmp_path_factory):
+    return models.mamba(model_name="state-spaces/mamba-130m-hf", device="cpu")
+
+
+@pytest.fixture(scope="session")
+def model_t5(tmp_path_factory):
+    from transformers import T5ForConditionalGeneration
+
+    return models.transformers(
+        "google/t5-efficient-mini", device="cpu", model_class=T5ForConditionalGeneration
     )
 
 
+# TODO: exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
+"""
 @pytest.fixture(scope="session")
-def model_mamba(tmp_path_factory):
-    return models.mamba(
-        model_name="state-spaces/mamba-130m-hf",
+def model_exllamav2(tmp_path_factory):
+    return models.exllamav2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
         device="cpu"
     )
-
-ALL_MODEL_FIXTURES = ("model_llamacpp", "model_mlxlm", "model_transformers", "model_vllm", "model_exllamav2", "model_mamba")
 """
 
 
@@ -57,6 +61,8 @@ def model_mamba(tmp_path_factory):
     "model_mlxlm",
     "model_transformers",
     "model_vllm",
+    "model_mamba",
+    "model_t5",
 )