diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md
index 757f42cdf54f40..4c4c19a3d6628d 100644
--- a/docs/source/en/generation_strategies.md
+++ b/docs/source/en/generation_strategies.md
@@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
+
+
+If you're using a `pipeline` object, all you need to do is pass the assistant checkpoint under `assistant_model`:
+
+```python
+>>> from transformers import pipeline
+>>> import torch
+
+>>> pipe = pipeline(
+...     "text-generation",
+...     model="meta-llama/Llama-3.1-8B",
+...     assistant_model="meta-llama/Llama-3.2-1B",  # This extra line is all that's needed, also works with UAD
+...     torch_dtype=torch.bfloat16
+... )
+>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
+>>> pipe_output[0]["generated_text"]
+'Once upon a time, 3D printing was a niche technology that was only'
+```
+
+
+
+
 When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness, just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index 8e87ead7fdd5a9..134666d45a3a9e 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -347,7 +347,6 @@ def generate(
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id

         if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index eb7396a8b5dc77..dd1f819fe548aa 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -773,7 +773,6 @@ def generate(
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id

         use_xla = not tf.executing_eagerly()
diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index 09958b5fca195b..66a9c49ea5f351 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -348,6 +348,12 @@ def _sanitize_parameters(
                 raise ValueError("Only Whisper can return language for now.")
             postprocess_params["return_language"] = return_language

+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params

     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index d2d4f198d41847..a24e9c3f697878 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -33,7 +33,7 @@
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..image_processing_utils import BaseImageProcessor
 from ..modelcard import ModelCard
-from ..models.auto.configuration_auto import AutoConfig
+from ..models.auto import AutoConfig, AutoTokenizer
 from ..processing_utils import ProcessorMixin
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import (
@@ -425,6 +425,62 @@ def get_default_model_and_revision(
     return default_models[framework]


+def load_assistant_model(
+    model: "PreTrainedModel",
+    assistant_model: Optional[Union[str, "PreTrainedModel"]],
+    assistant_tokenizer: Optional[PreTrainedTokenizer],
+) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
+    """
+    Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
+
+    Args:
+        model ([`PreTrainedModel`]):
+            The main model that will be used by the pipeline to make predictions.
+        assistant_model (`str` or [`PreTrainedModel`], *optional*):
+            The assistant model that will be used by the pipeline to make predictions.
+        assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
+            The assistant tokenizer that will be used by the pipeline to encode data for the model.
+
+    Returns:
+        Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
+    """
+    if not model.can_generate() or assistant_model is None:
+        return None, None
+
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError(
+            "Assisted generation, triggered by the `assistant_model` argument, is only available for "
+            "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
+        )
+
+    # If the model is passed as a string, load the model and the corresponding tokenizer
+    if isinstance(assistant_model, str):
+        assistant_config = AutoConfig.from_pretrained(assistant_model)
+        _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
+        loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
+        loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
+    else:
+        loaded_assistant_model = assistant_model
+        loaded_assistant_tokenizer = assistant_tokenizer
+
+    # Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
+    # tokenizer
+    same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
+    same_special_tokens = all(
+        getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
+        for token in ("eos_token_id", "pad_token_id", "bos_token_id")
+    )
+    if same_vocab_size and same_special_tokens:
+        loaded_assistant_tokenizer = None
+    elif loaded_assistant_tokenizer is None:
+        raise ValueError(
+            "The assistant model has a different tokenizer than the main model. You should pass the assistant "
+            "tokenizer."
+        )
+
+    return loaded_assistant_model, loaded_assistant_tokenizer
+
+
 class PipelineException(Exception):
     """
     Raised by a [`Pipeline`] when handling __call__.
@@ -925,8 +981,13 @@ def __init__(
         ):
             self.model.to(self.device)

-        # If the model can generate, create a local generation config. This is done to avoid side-effects on the model
-        # as we apply local tweaks to the generation config.
+        # If the model can generate:
+        # 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
+        # tweaks to the generation config.
+        # 2 - load the assistant model if it is passed.
+        self.assistant_model, self.assistant_tokenizer = load_assistant_model(
+            self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
+        )
         if self.model.can_generate():
             self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
             self.generation_config = copy.deepcopy(self.model.generation_config)
diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py
index 9198f432263822..c176d841e29fa6 100644
--- a/src/transformers/pipelines/document_question_answering.py
+++ b/src/transformers/pipelines/document_question_answering.py
@@ -189,7 +189,14 @@ def _sanitize_parameters(
         if handle_impossible_answer is not None:
             postprocess_params["handle_impossible_answer"] = handle_impossible_answer

-        return preprocess_params, {}, postprocess_params
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params

     def __call__(
         self,
diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py
index afd67b6ac9edee..32a3ec218dac30 100644
--- a/src/transformers/pipelines/image_to_text.py
+++ b/src/transformers/pipelines/image_to_text.py
@@ -92,6 +92,12 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt
                 )
             forward_params.update(generate_kwargs)

+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}

     def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py
index 77c95432c7218f..10ea7170fed40c 100644
--- a/src/transformers/pipelines/table_question_answering.py
+++ b/src/transformers/pipelines/table_question_answering.py
@@ -358,6 +358,13 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, *
         forward_params = {}
         if sequential is not None:
             forward_params["sequential"] = sequential
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}

     def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py
index 75ded8ac085ca5..9bc7544550286e 100644
--- a/src/transformers/pipelines/text2text_generation.py
+++ b/src/transformers/pipelines/text2text_generation.py
@@ -106,6 +106,12 @@ def _sanitize_parameters(
                 )
             generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params

     def check_inputs(self, input_length: int, min_length: int, max_length: int):
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index e15228fbe0aa6e..c0f14663ffdf58 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,7 +1,6 @@
 import enum
 import itertools
 import types
-import warnings
 from typing import Dict

 from ..utils import add_end_docstrings, is_tf_available, is_torch_available
@@ -194,12 +193,13 @@ def _sanitize_parameters(
         if stop_sequence is not None:
             stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
-            if len(stop_sequence_ids) > 1:
-                warnings.warn(
-                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
-                    " the stop sequence will be used as the stop sequence string in the interim."
-                )
-            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+            generate_kwargs["eos_token_id"] = stop_sequence_ids
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer

         return preprocess_params, forward_params, postprocess_params
diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py
index d17d18205920b0..b7beca586d2195 100644
--- a/src/transformers/pipelines/text_to_audio.py
+++ b/src/transformers/pipelines/text_to_audio.py
@@ -148,10 +148,9 @@ def _forward(self, model_inputs, **kwargs):
         else:
             if len(generate_kwargs):
                 raise ValueError(
-                    f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
-                     For forward-only TTA models, please use `forward_params` instead of of
-                    `generate_kwargs`. For reference, here are the `generate_kwargs` used here:
-                    {generate_kwargs.keys()}"""
+                    "You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
+                    "empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
+                    f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
                 )

             output = self.model(**model_inputs, **forward_params)[0]
@@ -191,6 +190,12 @@ def _sanitize_parameters(
         forward_params=None,
         generate_kwargs=None,
     ):
+        if self.assistant_model is not None:
+            generate_kwargs["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            generate_kwargs["tokenizer"] = self.tokenizer
+            generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
+
         params = {
             "forward_params": forward_params if forward_params else {},
             "generate_kwargs": generate_kwargs if generate_kwargs else {},
diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py
index 89988c0cba2b1b..6d600c9eaf50bc 100644
--- a/src/transformers/pipelines/visual_question_answering.py
+++ b/src/transformers/pipelines/visual_question_answering.py
@@ -66,7 +66,15 @@ def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, timeou
             preprocess_params["timeout"] = timeout
         if top_k is not None:
             postprocess_params["top_k"] = top_k
-        return preprocess_params, {}, postprocess_params
+
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params

     def __call__(
         self,
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index e8cd8febca006e..da57a002c4f5bc 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1933,6 +1933,20 @@ def test_slow_unfinished_sequence(self):
             },
         )

+    @require_torch
+    def test_pipeline_assisted_generation(self):
+        """Tests that we can run assisted generation in the pipeline"""
+        model = "openai/whisper-tiny"
+        pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
+
+        # We can run the pipeline
+        prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
+        _ = pipe(prompt)
+
+        # It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
+        with self.assertRaises(ValueError):
+            _ = pipe(prompt, generate_kwargs={"num_beams": 2})
+

 def require_ffmpeg(test_case):
     """
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 7de84e646e192d..d5014586b331c1 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -653,3 +653,17 @@ def test_pipeline_length_setting_warning(self):
         with CaptureLogger(logger) as cl:
             _ = text_generator(prompt, max_length=10)
         self.assertNotIn(logger_msg, cl.out)
+
+    @require_torch
+    def test_pipeline_assisted_generation(self):
+        """Tests that we can run assisted generation in the pipeline"""
+        model = "hf-internal-testing/tiny-random-MistralForCausalLM"
+        pipe = pipeline("text-generation", model=model, assistant_model=model)
+
+        # We can run the pipeline
+        prompt = "Hello world"
+        _ = pipe(prompt)
+
+        # It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
+        with self.assertRaises(ValueError):
+            _ = pipe(prompt, generate_kwargs={"num_beams": 2})
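
For reviewers who want to try the new keyword arguments end to end, here is a minimal usage sketch (not part of the patch) exercising both branches of `load_assistant_model` added in `pipelines/base.py`. The checkpoint names are illustrative placeholders; any `generate`-capable main/assistant pair should behave the same way.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Branch 1: assistant passed as a checkpoint name. Per the patch, the helper loads the
# checkpoint, moves it to the main model's device/dtype, loads its tokenizer, and drops
# that tokenizer again when the vocabulary and special tokens match the main model.
pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-3.1-8B",  # illustrative checkpoints, not prescribed by the patch
    assistant_model="meta-llama/Llama-3.2-1B",
    torch_dtype=torch.bfloat16,
)
print(pipe("Once upon a time, ", max_new_tokens=20, do_sample=False)[0]["generated_text"])

# Branch 2: preloaded assistant with a different vocabulary (universal assisted decoding).
# The helper does not relocate a preloaded assistant, so place it on the right device/dtype
# yourself, and pass its tokenizer explicitly or `load_assistant_model` raises a ValueError.
assistant = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m", torch_dtype=torch.bfloat16)
assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")
uad_pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-3.1-8B",
    assistant_model=assistant,
    assistant_tokenizer=assistant_tokenizer,
    torch_dtype=torch.bfloat16,
)
print(uad_pipe("Once upon a time, ", max_new_tokens=20, do_sample=False)[0]["generated_text"])
```

The first call relies on the helper discarding the assistant tokenizer when vocabularies match, so only `assistant_model` reaches `generate`; the second exercises the mismatched-vocabulary branch, where `_sanitize_parameters` also forwards `tokenizer` and `assistant_tokenizer` so that `generate` runs universal assisted decoding.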