diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
index 2d9880a6b..286df4367 100644
--- a/docs/reference/models/transformers.md
+++ b/docs/reference/models/transformers.md
@@ -3,7 +3,11 @@
 
 !!! Installation
 
-    You need to install the `transformer` and `datasets` libraries to be able to use these models in Outlines.
+    You need to install the `transformers`, `datasets` and `torch` libraries to be able to use these models in Outlines:
+
+    ```bash
+    pip install torch transformers datasets
+    ```
 
 Outlines provides an integration with the `torch` implementation of causal models in the
 [transformers][transformers] library. You can initialize the model by passing its name:
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 4fa0a3e79..3f4f182d2 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -1,12 +1,13 @@
 import datetime
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Union
-
-import torch
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
 
+if TYPE_CHECKING:
+    import torch
+
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
@@ -29,9 +30,9 @@ def __init__(
 
     def get_generated_token_ids(
         self,
-        prompt_token_ids: torch.Tensor,
-        token_ids: torch.Tensor,
-    ) -> List[torch.Tensor]:
+        prompt_token_ids: "torch.Tensor",
+        token_ids: "torch.Tensor",
+    ) -> List["torch.Tensor"]:
         """Get the tokens generated so far.
 
         Parameters
@@ -130,7 +131,7 @@ def __call__(
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional[torch.Generator] = None,
+        rng: Optional["torch.Generator"] = None,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
         """Generate the full text sequence.
 
@@ -157,6 +158,7 @@ def __call__(
         -------
         The generation(s), potentially cast to another type.
         """
+        import torch
 
         if isinstance(prompts, str):
            prompts = [prompts]
@@ -247,7 +249,7 @@ def stream(
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional[torch.Generator] = None,
+        rng: Optional["torch.Generator"] = None,
     ) -> Iterator[Union[List[str], str, List[List[str]]]]:
         """Generate the text sequence one token at a time.
 
@@ -274,6 +276,7 @@ def stream(
         A string or list of strings that contain the generated text.
 
         """
+        import torch
 
         if isinstance(prompts, str):
            prompts = [prompts]
diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py
index ad8ae8537..12c30e588 100644
--- a/outlines/generate/generator.py
+++ b/outlines/generate/generator.py
@@ -2,9 +2,9 @@
 import math
 from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple
 
-import torch
-
 if TYPE_CHECKING:
+    import torch
+
     from outlines.fsm.guide import Guide
 
 
@@ -29,7 +29,7 @@ def sequence_generator(
     sequence_weights: torch.Tensor,
     attention_masks: torch.Tensor,
     fsm_states: List[int],
-    rng: torch.Generator = torch.Generator(),
+    rng: "torch.Generator",
 ) -> Iterator[GenerationState]:
     """Generates sequences of tokens.
 
@@ -62,6 +62,11 @@ def sequence_generator(
     A new sequence.
 
""" + import torch + + if rng is None: + rng = torch.Generator() + kv_cache = None while True: @@ -107,7 +112,7 @@ def sequence_generator( def get_next_fsm_states( - fsms: List["Guide"], fsm_states: List[int], next_token_ids: torch.Tensor + fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor" ) -> List[int]: """ @@ -129,7 +134,7 @@ def get_next_fsm_states( ] -def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> torch.Tensor: +def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor": """Get the new instructions for each sequence from the finite-state machine. Parameters @@ -173,10 +178,9 @@ def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool: return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)]) -@torch.inference_mode() def update_token_ids( - token_ids: torch.Tensor, next_token_ids: torch.Tensor, ancestors: torch.Tensor -) -> torch.Tensor: + token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": """Append the sampled tokens to the running sequence of tokens. Parameters @@ -195,14 +199,15 @@ def update_token_ids( just generated. """ + import torch + token_ids = torch.index_select(token_ids, 0, ancestors) return torch.concatenate([token_ids, next_token_ids], dim=-1) -@torch.inference_mode() def update_attention_masks( - attention_masks: torch.Tensor, ancestors: torch.Tensor -) -> torch.Tensor: + attention_masks: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": """Expand the attention masks. Parameters @@ -217,6 +222,8 @@ def update_attention_masks( The attention masks padded with 1s. """ + import torch + attention_masks = torch.index_select(attention_masks, 0, ancestors) return torch.concatenate( [ @@ -229,7 +236,7 @@ def update_attention_masks( ) -def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]: +def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]: reordered_fsms = [] for ancestor in ancestors: reordered_fsms.append(fsms[ancestor].copy()) @@ -237,7 +244,7 @@ def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]: return reordered_fsms -def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]: +def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]: reordered_states = [] for ancestor in ancestors: reordered_states.append(fsm_states[ancestor]) @@ -246,7 +253,7 @@ def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[i def reorder_kv_cache( - kv_cache: Optional[Tuple], ancestors: torch.Tensor + kv_cache: Optional[Tuple], ancestors: "torch.Tensor" ) -> Optional[Tuple]: """Re-order the KV-cache based on the ancestors. @@ -270,8 +277,7 @@ def reorder_kv_cache( return new_kv_cache -@torch.inference_mode() -def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor: +def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor": """Mask the logits. The function iterates over a nested list where each list corresponds to the @@ -290,6 +296,8 @@ def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor: A view of the original logits tensor where some values are masked. 
""" + import torch + biased_logits = torch.full_like(logits, -math.inf, device=logits.device) for i, ids in enumerate(allowed_token_ids): biased_logits[i, ids] = logits[i, ids] diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 1b29ee2f4..c7ec0bdb1 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,10 +1,9 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import torch - from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: + import torch from transformers import PreTrainedModel, PreTrainedTokenizer __all__ = ["transformers"] @@ -77,13 +76,13 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs): def encode( self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple[torch.LongTensor, torch.LongTensor]: + ) -> Tuple["torch.LongTensor", "torch.LongTensor"]: kwargs["padding"] = True kwargs["return_tensors"] = "pt" output = self.tokenizer(prompt, **kwargs) return output["input_ids"], output["attention_mask"] - def decode(self, token_ids: torch.LongTensor) -> List[str]: + def decode(self, token_ids: "torch.LongTensor") -> List[str]: text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) return text @@ -127,13 +126,12 @@ def __init__( self.model = model self.tokenizer = TransformerTokenizer(tokenizer) - @torch.inference_mode def forward( self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, + input_ids: "torch.LongTensor", + attention_mask: "torch.LongTensor", past_key_values: Optional[Tuple] = None, - ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]: + ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]: """Compute a forward pass through the transformer model. Parameters @@ -151,28 +149,35 @@ def forward( The computed logits and the new cached key and value tensors. """ + try: + import torch + except ImportError: + ImportError( + "The `torch` library needs to be installed to use `transformers` models." + ) assert 0 < input_ids.ndim < 3 if past_key_values: input_ids = input_ids[..., -1].unsqueeze(-1) - output = self.model( - input_ids, - attention_mask=attention_mask, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - past_key_values=past_key_values, - ) + with torch.inference_mode(): + output = self.model( + input_ids, + attention_mask=attention_mask, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + past_key_values=past_key_values, + ) return output.logits, output.past_key_values def __call__( self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, + input_ids: "torch.LongTensor", + attention_mask: "torch.LongTensor", past_key_values: Optional[Tuple] = None, - ) -> torch.FloatTensor: + ) -> "torch.FloatTensor": logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) next_token_logits = logits[..., -1, :] diff --git a/pyproject.toml b/pyproject.toml index 3137f281f..5a0cc6986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,10 @@ dependencies = [ "cloudpickle", "diskcache", "pydantic>=2.0", - "torch>=2.1.0", "numba", "referencing", "jsonschema", "requests", - "transformers", ] dynamic = ["version"] @@ -47,7 +45,6 @@ test = [ "pytest-benchmark", "pytest-cov", "pytest-mock", - "transformers", "coverage[toml]>=5.1", "diff-cover", "accelerate", @@ -57,7 +54,9 @@ test = [ "llama-cpp-python", "huggingface_hub", "openai>=1.0.0", - "vllm" + "vllm", + "torch", + "transformers", ] serve = [ "vllm>=0.3.0",