diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index faf5266b84aea3..2bea00261951c7 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -83,6 +83,7 @@
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "ConfidenceCriteria",
         "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
@@ -225,6 +226,7 @@
         WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
+        ConfidenceCriteria,
         EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 7e4096c0aa4c80..62d5fb6eed0c49 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -108,6 +108,7 @@ def __init__(
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
         self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+        self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold

         # Set eos in assistant same as in target model
         self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
@@ -157,6 +158,7 @@ def __init__(
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
+        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold

         # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
         # greedily to maximize matches. Disables sampling-related flags to prevent warnings
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 4eedf2699b55b5..e2585b1b9ed49c 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -350,6 +350,11 @@ class GenerationConfig(PushToHubMixin):
             reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
             - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
             - `"constant"`: `num_assistant_tokens` stays unchanged during generation
+        assistant_confidence_threshold (`float`, *optional*):
+            The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is
+            lower than this threshold, the assistant model stops the current token generation iteration, even if the number of
+            _speculative tokens_ (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic
+            speculation lookahead from the paper "Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models".
         prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
             The number of tokens to be output as candidate tokens.
         max_matching_ngram_size (`int`, *optional*, default to `None`):
@@ -449,6 +454,7 @@ def __init__(self, **kwargs):
         # Assistant generation
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
+        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)

         # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
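For context, the new flag rides through `GenerationConfig` like any other generation kwarg. A minimal sketch of the resulting behavior (the `0.4` value is illustrative, not taken from this diff):

```python
from transformers import GenerationConfig

# Opting in stores the threshold on the config object
config = GenerationConfig(assistant_confidence_threshold=0.4)
print(config.assistant_confidence_threshold)  # 0.4

# Left unset, the attribute defaults to None, so the gating added to
# `_get_stopping_criteria` (see the utils.py hunk below) adds no criterion.
print(GenerationConfig().assistant_confidence_threshold)  # None
```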
diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index 7e98b11cf01a1c..b950a69f8b6492 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -467,6 +467,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         return is_done


+class ConfidenceCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the assistant model's confidence in its prediction for the current token is lower than the
+    threshold `model.generation_config.assistant_confidence_threshold`, even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached.
+
+    Args:
+        assistant_confidence_threshold (`float`):
+            The value of the threshold.
+    """
+
+    def __init__(self, assistant_confidence_threshold):
+        self.assistant_confidence_threshold = assistant_confidence_threshold
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        probs = scores[-1].softmax(-1)
+        p = probs[0, input_ids[0, -1]].item()
+        if p < self.assistant_confidence_threshold:
+            return True
+        return False
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index b966d72c643380..17a234c62b285e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -97,6 +97,7 @@
     WatermarkLogitsProcessor,
 )
 from .stopping_criteria import (
+    ConfidenceCriteria,
     EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
@@ -958,6 +959,13 @@ def _get_stopping_criteria(
             criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
         if generation_config._eos_token_tensor is not None:
             criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
+        if (
+            generation_config.assistant_confidence_threshold is not None
+            and generation_config.assistant_confidence_threshold > 0
+        ):
+            criteria.append(
+                ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
+            )
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
         return criteria
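To exercise the whole path end to end, a caller only needs to set the threshold on the assistant's generation config before assisted generation; `AssistedCandidateGenerator.__init__` (see the candidate_generator.py hunks above) then copies it onto the config used for the assistant's own `generate` calls, where `_get_stopping_criteria` picks it up. A usage sketch with placeholder checkpoint names (any causal LM pair works; the `0.4` value is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints: a large target model and a small assistant (draft) model
tokenizer = AutoTokenizer.from_pretrained("path/to/target-model")
model = AutoModelForCausalLM.from_pretrained("path/to/target-model")
assistant = AutoModelForCausalLM.from_pretrained("path/to/assistant-model")

# Stop drafting a speculative block early once the assistant's per-token
# confidence drops below the threshold
assistant.generation_config.assistant_confidence_threshold = 0.4

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```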
diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index a04dac96169e82..e8594dcdb07e90 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -26,6 +26,7 @@
     import torch

     from transformers.generation import (
+        ConfidenceCriteria,
         EosTokenCriteria,
         MaxLengthCriteria,
         MaxTimeCriteria,
@@ -100,6 +101,23 @@ def test_eos_token_criteria(self):
         input_ids[:, -1] = 1
         self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])

+    def test_confidence_criteria(self):
+        criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
+
+        vocab_size = 250
+        length = 5
+
+        input_ids = ids_tensor((1, length), vocab_size)
+        scores = (torch.randn((1, vocab_size)),)
+
+        # Simulate high confidence by giving the last token a large logit (scores hold pre-softmax logits)
+        scores[0][0, input_ids[0, -1]] = 10.0
+        self.assertFalse(criteria(input_ids, scores))
+
+        # Simulate low confidence by giving the last token a very small logit
+        scores[0][0, input_ids[0, -1]] = -10.0
+        self.assertTrue(criteria(input_ids, scores))
+
     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
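As a sanity check on the test's magic numbers: with 250 standard-normal logits, forcing one logit to ±10 pushes that token's softmax probability far above or below the 0.5 threshold, so the assertions are robust to the random draw. A standalone sketch (depends only on `torch`; exact values vary with the seed):

```python
import torch

torch.manual_seed(0)
vocab_size = 250
scores = torch.randn((1, vocab_size))
token = 7  # arbitrary index standing in for the last generated token

scores[0, token] = 10.0
print(scores.softmax(-1)[0, token].item())  # ~0.98 -> above 0.5, criterion returns False

scores[0, token] = -10.0
print(scores.softmax(-1)[0, token].item())  # ~1e-7 -> below 0.5, criterion returns True
```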