Dynamic number of speculative tokens in order to accelerate speculative decoding (huggingface#33258)

* optimal Speculation Lookahead based on probability

* update peer finished condition

* add support for do_sample=True

* add stopping criteria

* gitignore

* add print

* remove prints

* minor

* minor

* git ignore

* add test for the ConfidenceCriteria stopping criterion

* doc + format

* add doc

* Update .gitignore

* update docstring and default value of assistant_confidence_threshold

* add docstring

* Update src/transformers/generation/configuration_utils.py

implicit default value (None)

Co-authored-by: Joao Gante <[email protected]>

* style fix

---------

Co-authored-by: Joao Gante <[email protected]>
jmamou and gante authored Sep 11, 2024
1 parent 42babe8 commit 7a51cbc
Showing 6 changed files with 57 additions and 0 deletions.
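
For illustration, a minimal sketch of how the new option might be used end to end once this commit is in place. The checkpoints, prompt, and threshold value are placeholders; `generate(..., assistant_model=...)` is the existing assisted-generation entry point, and `assistant_confidence_threshold` is the attribute added here:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder checkpoints: any compatible target/assistant pair works.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
    assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    # New in this commit: the assistant stops drafting candidate tokens early
    # once its confidence in the current token falls below this threshold.
    assistant.generation_config.assistant_confidence_threshold = 0.4

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))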
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -83,6 +83,7 @@
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "ConfidenceCriteria",
         "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
@@ -225,6 +226,7 @@
             WhisperTimeStampLogitsProcessor,
         )
         from .stopping_criteria import (
+            ConfidenceCriteria,
             EosTokenCriteria,
             MaxLengthCriteria,
             MaxNewTokensCriteria,
2 changes: 2 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -108,6 +108,7 @@ def __init__(
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
         self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+        self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold
 
         # Set eos in assistant same as in target model
         self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
@@ -157,6 +158,7 @@ def __init__(
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
+        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
 
         # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
         # greedily to maximize matches. Disables sampling-related flags to prevent warnings
6 changes: 6 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -350,6 +350,11 @@ class GenerationConfig(PushToHubMixin):
             reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
             - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
             - `"constant"`: `num_assistant_tokens` stays unchanged during generation
+        assistant_confidence_threshold (`float`, *optional*):
+            The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current
+            token is lower than this threshold, the assistant model stops generating candidate tokens for the current iteration, even
+            if the number of _speculative tokens_ (defined by `num_assistant_tokens`) has not yet been reached. This is an unsupervised
+            variant of the dynamic speculation lookahead from [Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models](https://arxiv.org/abs/2405.04304).
         prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
             The number of tokens to be output as candidate tokens.
         max_matching_ngram_size (`int`, *optional*, default to `None`):
@@ -449,6 +454,7 @@ def __init__(self, **kwargs):
         # Assistant generation
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
+        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)
 
         # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
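
A quick sketch of the new configuration surface (the values are illustrative; leaving `assistant_confidence_threshold` unset keeps the default of `None`, which disables the early stop):

    from transformers import GenerationConfig

    config = GenerationConfig(
        num_assistant_tokens=5,
        num_assistant_tokens_schedule="heuristic",
        assistant_confidence_threshold=0.4,  # illustrative value
    )
    print(config.assistant_confidence_threshold)  # 0.4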
21 changes: 21 additions & 0 deletions src/transformers/generation/stopping_criteria.py
@@ -467,6 +467,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         return is_done
 
 
+class ConfidenceCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the assistant model's confidence in its prediction for the current token is
+    lower than the threshold `model.generation_config.assistant_confidence_threshold`, even if the number of speculative tokens
+    (defined by `num_assistant_tokens`) has not yet been reached.
+
+    Args:
+        assistant_confidence_threshold (`float`):
+            The value of the threshold.
+    """
+
+    def __init__(self, assistant_confidence_threshold):
+        self.assistant_confidence_threshold = assistant_confidence_threshold
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        probs = scores[-1].softmax(-1)
+        p = probs[0, input_ids[0, -1]].item()
+        if p < self.assistant_confidence_threshold:
+            return True
+        return False
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
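
To make the semantics concrete, a small sketch that exercises the new criterion directly on synthetic logits (shapes follow the `__call__` above; the token id and values are made up). Note that, as written, `__call__` returns a plain Python bool rather than the annotated `torch.BoolTensor`:

    import torch

    from transformers.generation import ConfidenceCriteria

    criterion = ConfidenceCriteria(assistant_confidence_threshold=0.5)

    input_ids = torch.tensor([[42]])  # last generated token id
    logits = torch.zeros((1, 100))    # flat distribution -> p(42) = 0.01
    scores = (logits,)
    print(criterion(input_ids, scores))  # True: 0.01 < 0.5, stop drafting

    logits[0, 42] = 10.0              # peaked on token 42 -> p(42) ~ 1.0
    print(criterion(input_ids, scores))  # False: confident, keep drafting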
8 changes: 8 additions & 0 deletions src/transformers/generation/utils.py
@@ -97,6 +97,7 @@
     WatermarkLogitsProcessor,
 )
 from .stopping_criteria import (
+    ConfidenceCriteria,
     EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
@@ -958,6 +959,13 @@ def _get_stopping_criteria(
             criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
         if generation_config._eos_token_tensor is not None:
             criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
+        if (
+            generation_config.assistant_confidence_threshold is not None
+            and generation_config.assistant_confidence_threshold > 0
+        ):
+            criteria.append(
+                ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
+            )
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
         return criteria
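
The guard above means the criterion is only registered for a positive threshold; `None` (the default) and `0` both leave it out. A standalone sketch of the same wiring (the threshold value is illustrative):

    from transformers.generation import ConfidenceCriteria, StoppingCriteriaList

    threshold = 0.4  # illustrative; None or 0 would skip the append below
    criteria = StoppingCriteriaList()
    if threshold is not None and threshold > 0:
        criteria.append(ConfidenceCriteria(assistant_confidence_threshold=threshold))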
18 changes: 18 additions & 0 deletions tests/generation/test_stopping_criteria.py
@@ -26,6 +26,7 @@
 import torch
 
 from transformers.generation import (
+    ConfidenceCriteria,
     EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
@@ -100,6 +101,23 @@ def test_eos_token_criteria(self):
         input_ids[:, -1] = 1
         self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
 
+    def test_confidence_criteria(self):
+        criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
+
+        vocab_size = 250
+        length = 5
+
+        input_ids = ids_tensor((1, length), vocab_size)
+        scores = (torch.randn((1, vocab_size)),)
+
+        # Simulate high confidence by giving the last generated token a large logit (pre-softmax)
+        scores[0][0, input_ids[0, -1]] = 10.0
+        self.assertFalse(criteria(input_ids, scores))
+
+        # Simulate low confidence by giving the last generated token a small logit (pre-softmax)
+        scores[0][0, input_ids[0, -1]] = -10.0
+        self.assertTrue(criteria(input_ids, scores))
+
     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
