Pipeline: simple API for assisted generation (#34504)
Co-authored-by: Matt <[email protected]>
gante and Rocketknight1 authored Jan 8, 2025
1 parent 3f483be commit 76da6ca
Showing 14 changed files with 172 additions and 18 deletions.
22 changes: 22 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

<Tip>

If you're using a `pipeline` object, all you need to do is pass the assistant checkpoint under `assistant_model`.

```python
>>> from transformers import pipeline
>>> import torch

>>> pipe = pipeline(
... "text-generation",
... model="meta-llama/Llama-3.1-8B",
... assistant_model="meta-llama/Llama-3.2-1B", # This extra line is all that's needed, also works with UAD
... torch_dtype=torch.bfloat16
... )
>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
>>> pipe_output[0]["generated_text"]
'Once upon a time, 3D printing was a niche technology that was only'
```

</Tip>


When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve latency.
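
As a minimal, illustrative sketch (not part of this diff), the same `pipeline` call can combine the assistant model with sampling and a reduced temperature; the checkpoints are the ones from the snippet above:

```python
>>> from transformers import pipeline
>>> import torch

>>> pipe = pipeline(
...     "text-generation",
...     model="meta-llama/Llama-3.1-8B",
...     assistant_model="meta-llama/Llama-3.2-1B",
...     torch_dtype=torch.bfloat16,
... )
>>> # sampling kwargs are forwarded to `generate`; a lower temperature tends to help latency
>>> output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=True, temperature=0.5)
```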

1 change: 0 additions & 1 deletion src/transformers/generation/flax_utils.py
@@ -347,7 +347,6 @@ def generate(
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
1 change: 0 additions & 1 deletion src/transformers/generation/tf_utils.py
@@ -773,7 +773,6 @@ def generate(
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

use_xla = not tf.executing_eagerly()
6 changes: 6 additions & 0 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -348,6 +348,12 @@ def _sanitize_parameters(
raise ValueError("Only Whisper can return language for now.")
postprocess_params["return_language"] = return_language

if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, postprocess_params

def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
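For reference, a rough sketch of what the forwarding above enables on the speech pipeline; `openai/whisper-tiny` is reused as its own assistant purely for illustration (the new test below does the same), and the audio path is a placeholder:

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    assistant_model="openai/whisper-tiny",  # a smaller or distilled checkpoint is the typical choice
)
transcription = asr("sample.wav")  # "sample.wav" is a hypothetical local audio file
```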
67 changes: 64 additions & 3 deletions src/transformers/pipelines/base.py
Expand Up @@ -33,7 +33,7 @@
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto import AutoConfig, AutoTokenizer
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
@@ -425,6 +425,62 @@ def get_default_model_and_revision(
return default_models[framework]


def load_assistant_model(
model: "PreTrainedModel",
assistant_model: Optional[Union[str, "PreTrainedModel"]],
assistant_tokenizer: Optional[PreTrainedTokenizer],
) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
"""
Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
Args:
model ([`PreTrainedModel`]):
The main model that will be used by the pipeline to make predictions.
assistant_model (`str` or [`PreTrainedModel`], *optional*):
The assistant model that will be used by the pipeline to make predictions.
assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
The assistant tokenizer that will be used by the pipeline to encode data for the model.
Returns:
Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
"""
if not model.can_generate() or assistant_model is None:
return None, None

if not isinstance(model, PreTrainedModel):
raise ValueError(
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
)

# If the model is passed as a string, load the model and the corresponding tokenizer
if isinstance(assistant_model, str):
assistant_config = AutoConfig.from_pretrained(assistant_model)
_, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
else:
loaded_assistant_model = assistant_model
loaded_assistant_tokenizer = assistant_tokenizer

# Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
# tokenizer
same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
same_special_tokens = all(
getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
for token in ("eos_token_id", "pad_token_id", "bos_token_id")
)
if same_vocab_size and same_special_tokens:
loaded_assistant_tokenizer = None
elif loaded_assistant_tokenizer is None:
raise ValueError(
"The assistant model has a different tokenizer than the main model. You should pass the assistant "
"tokenizer."
)

return loaded_assistant_model, loaded_assistant_tokenizer


class PipelineException(Exception):
"""
Raised by a [`Pipeline`] when handling __call__.
@@ -925,8 +981,13 @@ def __init__(
):
self.model.to(self.device)

# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
# as we apply local tweaks to the generation config.
# If the model can generate:
# 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
# tweaks to the generation config.
# 2 - load the assistant model if it is passed.
self.assistant_model, self.assistant_tokenizer = load_assistant_model(
self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
)
if self.model.can_generate():
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
self.generation_config = copy.deepcopy(self.model.generation_config)
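Putting `load_assistant_model` together with the `__init__` change: a hedged sketch of building a pipeline with a pre-loaded assistant whose vocabulary differs from the main model's, in which case `assistant_tokenizer` must be passed or the helper raises a `ValueError`. Checkpoint names are illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# The assistant uses a different vocabulary than the main model, so its tokenizer is
# passed explicitly; otherwise load_assistant_model() raises a ValueError.
assistant = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")
assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")

pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-3.1-8B",
    assistant_model=assistant,
    assistant_tokenizer=assistant_tokenizer,
)
```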
9 changes: 8 additions & 1 deletion src/transformers/pipelines/document_question_answering.py
@@ -189,7 +189,14 @@ def _sanitize_parameters(
if handle_impossible_answer is not None:
postprocess_params["handle_impossible_answer"] = handle_impossible_answer

return preprocess_params, {}, postprocess_params
forward_params = {}
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, postprocess_params

def __call__(
self,
6 changes: 6 additions & 0 deletions src/transformers/pipelines/image_to_text.py
@@ -92,6 +92,12 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt
)
forward_params.update(generate_kwargs)

if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, {}

def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
7 changes: 7 additions & 0 deletions src/transformers/pipelines/table_question_answering.py
@@ -358,6 +358,13 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, *
forward_params = {}
if sequential is not None:
forward_params["sequential"] = sequential

if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, {}

def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
6 changes: 6 additions & 0 deletions src/transformers/pipelines/text2text_generation.py
@@ -106,6 +106,12 @@ def _sanitize_parameters(
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, postprocess_params

def check_inputs(self, input_length: int, min_length: int, max_length: int):
14 changes: 7 additions & 7 deletions src/transformers/pipelines/text_generation.py
@@ -1,7 +1,6 @@
import enum
import itertools
import types
import warnings
from typing import Dict

from ..utils import add_end_docstrings, is_tf_available, is_torch_available
@@ -194,12 +193,13 @@ def _sanitize_parameters(

if stop_sequence is not None:
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
if len(stop_sequence_ids) > 1:
warnings.warn(
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
" the stop sequence will be used as the stop sequence string in the interim."
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
generate_kwargs["eos_token_id"] = stop_sequence_ids

if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, postprocess_params

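Separate from the assisted-generation plumbing, the hunk above also stops truncating `stop_sequence`: every token id of the encoded stop sequence is now forwarded to `generate` as `eos_token_id`, not just the first one. A rough sketch of the call path, reusing the tiny checkpoint from the test file below:

```python
from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-MistralForCausalLM")
# `stop_sequence` is tokenized in `_sanitize_parameters`; all resulting ids are passed
# to `generate` as `eos_token_id` rather than only the first id.
out = generator("Hello world", stop_sequence=" and", max_new_tokens=20)
```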
13 changes: 9 additions & 4 deletions src/transformers/pipelines/text_to_audio.py
@@ -148,10 +148,9 @@ def _forward(self, model_inputs, **kwargs):
else:
if len(generate_kwargs):
raise ValueError(
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
For forward-only TTA models, please use `forward_params` instead of of
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
{generate_kwargs.keys()}"""
"You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
"empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
)
output = self.model(**model_inputs, **forward_params)[0]

@@ -191,6 +190,12 @@ def _sanitize_parameters(
forward_params=None,
generate_kwargs=None,
):
if self.assistant_model is not None:
generate_kwargs["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
generate_kwargs["tokenizer"] = self.tokenizer
generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer

params = {
"forward_params": forward_params if forward_params else {},
"generate_kwargs": generate_kwargs if generate_kwargs else {},
10 changes: 9 additions & 1 deletion src/transformers/pipelines/visual_question_answering.py
@@ -66,7 +66,15 @@ def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, timeou
preprocess_params["timeout"] = timeout
if top_k is not None:
postprocess_params["top_k"] = top_k
return preprocess_params, {}, postprocess_params

forward_params = {}
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, postprocess_params

def __call__(
self,
14 changes: 14 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1933,6 +1933,20 @@ def test_slow_unfinished_sequence(self):
},
)

@require_torch
def test_pipeline_assisted_generation(self):
"""Tests that we can run assisted generation in the pipeline"""
model = "openai/whisper-tiny"
pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)

# We can run the pipeline
prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
_ = pipe(prompt)

# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})


def require_ffmpeg(test_case):
"""
14 changes: 14 additions & 0 deletions tests/pipelines/test_pipelines_text_generation.py
@@ -653,3 +653,17 @@ def test_pipeline_length_setting_warning(self):
with CaptureLogger(logger) as cl:
_ = text_generator(prompt, max_length=10)
self.assertNotIn(logger_msg, cl.out)

@require_torch
def test_pipeline_assisted_generation(self):
"""Tests that we can run assisted generation in the pipeline"""
model = "hf-internal-testing/tiny-random-MistralForCausalLM"
pipe = pipeline("text-generation", model=model, assistant_model=model)

# We can run the pipeline
prompt = "Hello world"
_ = pipe(prompt)

# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
