Universal Speculative Decoding CandidateGenerator #35029

Open. Wants to merge 72 commits into base: main.

Commits (72):
aa7e01a  move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new t… (keyboardAnt, Nov 28, 2024)
f6b7f20  refactor (keyboardAnt, Nov 28, 2024)
0ded37c  NOTHING. add space to rerun github actions tests (keyboardAnt, Nov 28, 2024)
d48b69b  remove it... (keyboardAnt, Nov 28, 2024)
b47e33a  `UniversalSpeculativeDecodingGenerator` (keyboardAnt, Nov 16, 2024)
8a99129  Use `UniversalSpeculativeDecodingGenerator` when `generation_config.d… (keyboardAnt, Nov 16, 2024)
4649bd2  assistant tokenizes only the target's new suffix (keyboardAnt, Nov 16, 2024)
f199c94  formatting (keyboardAnt, Nov 16, 2024)
19c0057  fix code (jmamou, Nov 21, 2024)
acf5a4b  fix code (jmamou, Nov 24, 2024)
3712117  formatting (keyboardAnt, Nov 24, 2024)
63f2f46  add `TestGenerateWithDifferentModels` (keyboardAnt, Nov 24, 2024)
6ac33f1  `TestGenerateWithDifferentModels` parameterize on `do_sample` (keyboardAnt, Nov 24, 2024)
6938311  `AssistantVocabMapping` & `AssistantVocabMappingCache` (keyboardAnt, Nov 24, 2024)
5a0db3b  formatting (keyboardAnt, Nov 24, 2024)
92f8ad3  `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_l… (keyboardAnt, Nov 24, 2024)
7c8708e  improve `_get_assistant_to_target_input_ids` & formatting (keyboardAnt, Nov 24, 2024)
880d0ae  renaming (keyboardAnt, Nov 24, 2024)
d9b5e74  WIP: debugging `min_new_tokens` (keyboardAnt, Nov 25, 2024)
25974d5  fix get_target_ids (jmamou, Nov 25, 2024)
b8636ab  `UniversalSpeculativeDecodingGenerator` (keyboardAnt, Nov 16, 2024)
1ef46b7  assistant tokenizes only the target's new suffix (keyboardAnt, Nov 16, 2024)
f8e94eb  formatting (keyboardAnt, Nov 16, 2024)
439db84  fix code (jmamou, Nov 21, 2024)
643901d  fix code (jmamou, Nov 24, 2024)
77097ff  formatting (keyboardAnt, Nov 24, 2024)
d08b4f0  `TestGenerateWithDifferentModels` parameterize on `do_sample` (keyboardAnt, Nov 24, 2024)
f242dc1  `AssistantVocabMapping` & `AssistantVocabMappingCache` (keyboardAnt, Nov 24, 2024)
ede1176  formatting (keyboardAnt, Nov 24, 2024)
511ee96  `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_l… (keyboardAnt, Nov 24, 2024)
5e47945  improve `_get_assistant_to_target_input_ids` & formatting (keyboardAnt, Nov 24, 2024)
25a4349  renaming (keyboardAnt, Nov 24, 2024)
95fe744  WIP: debugging `min_new_tokens` (keyboardAnt, Nov 25, 2024)
0ad88b2  fix get_target_ids (jmamou, Nov 25, 2024)
bc5fa61  fix device issue (jmamou, Nov 25, 2024)
41a5670  fix get_assistant_input_ids (jmamou, Nov 25, 2024)
44f7ba7  add `TestAssistedCandidateGeneratorDifferentTokenizers` (keyboardAnt, Nov 26, 2024)
57aafcc  formatting (keyboardAnt, Nov 26, 2024)
6f95c33  `AssistantVocabTranslatorCache` refactor & tests (keyboardAnt, Nov 26, 2024)
078f763  revert changes in `src/transformers/generation/logits_process.py` (keyboardAnt, Nov 26, 2024)
faac2fc  refactor `AssistedCandidateGenerator` (keyboardAnt, Nov 26, 2024)
76a2dd3  refactor `AssistedCandidateGeneratorDifferentTokenizers` (keyboardAnt, Nov 26, 2024)
43e96e7  formatting (keyboardAnt, Nov 26, 2024)
e63cb9d  refactor `UniversalSpeculativeDecodingGenerator` (keyboardAnt, Nov 26, 2024)
8aa6020  fix negative value for max_new_tokens (jmamou, Nov 26, 2024)
2169973  fix generation length target + attention_mask vs. assistant + attent… (jmamou, Nov 26, 2024)
c6da827  fix device (jmamou, Nov 26, 2024)
2cf9e8e  fix negative max_new_tokens bug (jmamou, Nov 27, 2024)
a1c0d05  fix UAG (jmamou, Nov 28, 2024)
d830091  minor (jmamou, Nov 28, 2024)
19d0cce  formatting (keyboardAnt, Nov 28, 2024)
5b8217d  `AssistedCandidateGeneratorDifferentTokenizers` `lookbehind`s init (keyboardAnt, Nov 28, 2024)
9b0126a  resolve conflict & formatting (keyboardAnt, Nov 30, 2024)
578d0b3  rerun CI tests (keyboardAnt, Nov 30, 2024)
7db2695  remove space... (keyboardAnt, Nov 30, 2024)
fb69900  remove old code (keyboardAnt, Dec 3, 2024)
e40c775  fix candidate_input_ids device (jmamou, Dec 4, 2024)
b5ce873  minor (jmamou, Dec 4, 2024)
bfccdea  Merge pull request #4 from keyboardAnt/fix_device (keyboardAnt, Dec 5, 2024)
d34d7ea  formatting (keyboardAnt, Dec 5, 2024)
9d4d9f9  Fix prepare + apply (#7) (jmamou, Dec 17, 2024)
4e92e9c  Add unittests for Universal Assisted generation (gauravj14, Dec 12, 2024)
3fe2d31  Merge branch 'main' into usd (jmamou, Dec 18, 2024)
a350b1c  fix style (jmamou, Dec 18, 2024)
e047adf  update tests (jmamou, Dec 18, 2024)
011f595  Remove unused import and fix `test_speculation_depth` test (gauravjain14, Dec 17, 2024)
2652490  exclude special and reserved tokens from tokenizer for UAG (gauravjain14, Dec 18, 2024)
701edbb  mv `test_universal_assisted_generation.py` to `generation/test_candid… (gauravjain14, Dec 19, 2024)
7088978  Merge pull request #8 from keyboardAnt/unit_tests_usd (gauravjain14, Dec 19, 2024)
3b89341  Remove unused imports and fix style using `make style` (#9) (gauravjain14, Dec 20, 2024)
e43dba8  formatting (keyboardAnt, Dec 21, 2024)
a529795  Swap gated `meta-llama/llama-3.2` with `allenai/llama` (#10) (gauravjain14, Dec 21, 2024)
253 changes: 250 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -14,6 +14,8 @@
# limitations under the License.

import copy
import threading
import weakref
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import numpy as np
@@ -27,7 +29,11 @@

from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
-from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
+from .logits_process import (
+    LogitsProcessorList,
+    MinLengthLogitsProcessor,
+    SuppressTokensLogitsProcessor,
+)


if TYPE_CHECKING:
@@ -284,18 +290,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
        return min_new_tokens, max_new_tokens

-    def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+    def _update_past_and_masks(
+        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
+    ) -> bool:
        """Update past key values and attention masks for subsequent generation rounds."""
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
            )
            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])

        return has_past_key_values

    def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
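With `num_added_tokens` replacing the hard-coded `1`, the cache is cropped so that every freshly appended assistant token is re-processed on the next forward pass. A minimal sketch of the arithmetic, with made-up lengths (not part of the diff):

```python
# Toy walk-through of the crop above; all lengths are hypothetical.
input_len = 12        # assistant_input_ids.shape[-1]
remove_from_pkv = 0
num_added_tokens = 2  # two re-tokenized tokens appended this round

new_cache_size = input_len - 1 - remove_from_pkv  # 11
crop_to = new_cache_size - num_added_tokens       # 9
# Keeping 9 cached positions makes the assistant re-process the last
# 12 - 9 = 3 positions, which cover the 2 newly appended tokens.
assert crop_to == 9
```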
@@ -609,6 +618,244 @@ def _process_assistant_outputs(
        return new_target_ids


class AssistantToTargetTranslator:
    """
    Translates the assistant's token ids and logits into the target "universe", i.e. the target model's vocabulary.
    """

    def __init__(
        self,
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        assistant_model_device,
        target_vocab_size: int,
        filter_value: float = -float("Inf"),
        suppress_tokens_id: int = -1,
    ):
        self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
        self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
        self._assistant_model_device = assistant_model_device
        self.target_vocab_size: int = target_vocab_size
        self.filter_value: float = filter_value
        self.suppress_tokens_id: int = suppress_tokens_id
        self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
        self.logits_processors: LogitsProcessorList = LogitsProcessorList(
            [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
        )

    def _get_assistant_to_target_input_ids(self):
        target_vocab = self._target_tokenizer.get_vocab()
        assistant_vocab = self._assistant_tokenizer.get_vocab()
        max_assistant_index = max(assistant_vocab.values())
        # Assistant tokens missing from the target vocab keep `suppress_tokens_id` and are suppressed while drafting.
        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
        for tok, idx in assistant_vocab.items():
            if tok in target_vocab:
                assistant_to_target_input_ids[idx] = target_vocab[tok]
        return assistant_to_target_input_ids.to(self._assistant_model_device)

    def _get_suppress_input_ids(self) -> torch.Tensor:
        """
        Get the input ids that are in the assistant vocab but not in the target vocab.
        """
        return torch.where(self._assistant_to_target_input_ids == self.suppress_tokens_id)[0]

    def get_target_ids(
        self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
    ) -> torch.LongTensor:
        """
        Return the target candidate ids that correspond to the assistant candidate ids.
        Note that we already have the target ids for the prompt and only need to find the target ids for the new tokens.
        Moreover, assistant ids of the original prompt do not necessarily appear in _assistant_to_target_input_ids.
        """

        num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
        if num_new_tokens == 0:
            return target_input_ids
        else:
            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)

    def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
        """
        Return the target logits that correspond to the assistant logits.
        """

        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
        target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device)
        # Mask for valid indices
        assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id
        # Exclude invalid indices
        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
        valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]

        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]

        return target_logits
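To make the mapping concrete, here is a small self-contained sketch of what `_get_assistant_to_target_input_ids` and `_get_suppress_input_ids` compute, using two made-up four-token vocabularies rather than real tokenizers:

```python
import torch

# Made-up vocabularies; the real ones come from tokenizer.get_vocab().
target_vocab = {"<s>": 0, "hello": 1, "world": 2, "!": 3}
assistant_vocab = {"<s>": 0, "world": 1, "hello": 2, "zzz": 3}

suppress_tokens_id = -1
max_assistant_index = max(assistant_vocab.values())
assistant_to_target = torch.full((max_assistant_index + 1,), suppress_tokens_id, dtype=torch.long)
for tok, idx in assistant_vocab.items():
    if tok in target_vocab:
        assistant_to_target[idx] = target_vocab[tok]

print(assistant_to_target)  # tensor([ 0,  2,  1, -1]): "zzz" has no target equivalent
suppress_ids = torch.where(assistant_to_target == suppress_tokens_id)[0]
print(suppress_ids)         # tensor([3]): this assistant token is suppressed while drafting
```

Because every draft token the assistant is allowed to emit maps to exactly one target id, `get_target_ids` reduces to a single tensor lookup over the newly generated suffix.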


class AssistantVocabTranslatorCache:
    """
    Cache for `AssistantToTargetTranslator` instances. The instances are computed at
    pre-processing time, and this cache allows us to avoid recomputing them.
    """

    _lock = threading.Lock()
    _cache = weakref.WeakKeyDictionary()

    @classmethod
    def get_translator(
        cls,
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        assistant_model_device,
        target_vocab_size: int,
    ) -> AssistantToTargetTranslator:
        with cls._lock:
            assistant_dict = cls._cache.get(target_tokenizer)
            if assistant_dict is None:
                assistant_dict = weakref.WeakKeyDictionary()
                cls._cache[target_tokenizer] = assistant_dict

            mapping = assistant_dict.get(assistant_tokenizer)
            if mapping is None:
                mapping = AssistantToTargetTranslator(
                    target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size
                )
                assistant_dict[assistant_tokenizer] = mapping

            return mapping

    @classmethod
    def cleanup(cls):
        """
        Clean up dead references in the cache.
        This removes entries where either the target_tokenizer or assistant_tokenizer
        has been garbage collected.
        """
        with cls._lock:
            # Remove entries from the outer cache where the target_tokenizer is no longer alive
            dead_keys = [key for key in cls._cache if key is None]
            for key in dead_keys:
                del cls._cache[key]

            # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive
            for assistant_dict in cls._cache.values():
                dead_keys = [key for key in assistant_dict if key is None]
                for key in dead_keys:
                    del assistant_dict[key]
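A hedged usage sketch (checkpoint names and the vocab size are placeholders): repeated calls with the same tokenizer pair return the same translator instance, so the vocabulary scan runs once per pair, and the `WeakKeyDictionary` keys let cache entries be collected together with their tokenizers:

```python
from transformers import AutoTokenizer

target_tok = AutoTokenizer.from_pretrained("target-org/target-model")   # placeholder
assistant_tok = AutoTokenizer.from_pretrained("draft-org/draft-model")  # placeholder

t1 = AssistantVocabTranslatorCache.get_translator(target_tok, assistant_tok, "cpu", target_vocab_size=32000)
t2 = AssistantVocabTranslatorCache.get_translator(target_tok, assistant_tok, "cpu", target_vocab_size=32000)
assert t1 is t2  # cached: the vocab mapping tensor is not rebuilt
```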


class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
    """
    `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
    for the assistant and main models. This class generates candidates through the use of a smaller model.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        generation_config: "GenerationConfig",
        model_kwargs: Dict,
        target_vocab_size: int,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: "LogitsProcessorList" = None,
    ):
        # Initialize translator before parent class
        self._atm_translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer, assistant_tokenizer, assistant_model.device, target_vocab_size
        )
        super().__init__(
            input_ids,
            assistant_model,
            target_tokenizer,
            assistant_tokenizer,
            generation_config,
            model_kwargs,
            inputs_tensor,
            logits_processor,
        )
        # Track sequence lengths and previous assistant IDs
        self._target_seq_len_with_candidates: int = 0
        self._prev_assistant_ids: Optional[torch.LongTensor] = None
        self.target_vocab_size = target_vocab_size

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Simplified version of get_candidates that uses the translator cache for token conversion.
        """
        target_input_ids = input_ids.to(self.assistant_model.device)
        assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
        min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)

        if max_new_tokens == 0:
            return input_ids, None

        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
        generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)

        # Ensure scores are returned
        generation_args["generation_config"].output_scores = True
        generation_args["generation_config"].return_dict_in_generate = True

        # Generate and process outputs using translator
        generation_args["logits_processor"] = self._atm_translator.logits_processors
        self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)

        # Use translator to convert tokens and logits
        target_candidate_ids = self._atm_translator.get_target_ids(
            assistant_input_ids, target_input_ids, self._prev_assistant_ids
        )
        self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
        target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)

        return target_candidate_ids, target_candidate_logits

    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
        if self._prev_assistant_ids is None:
            # Prepare attention mask for the first generation.
            # For subsequent generations, the attention mask is updated in super()._update_past_and_masks.
            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
            )
        return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)

    def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
        """
        Simplified token conversion that only processes new tokens.
        """
        # Calculate new tokens since last call
        target_seq_len = target_input_ids.shape[-1]
        if self._target_seq_len_with_candidates == 0:
            new_token_count = target_seq_len
        else:
            new_token_count = 1
        target_new_ids = target_input_ids[:, -new_token_count:]

        # Convert only the new tokens
        target_new_text = self.target_tokenizer.batch_decode(
            target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        assistant_new_ids = self.assistant_tokenizer(target_new_text, add_special_tokens=False, return_tensors="pt")[
            "input_ids"
        ].to(self.assistant_model.device)
Comment on lines +837 to +843

Member: Hmm, converting only the new tokens and concatenating them with the old assistant ids means that the total assistant ids might not always be the actual tokenization of the input text, doesn't it? Since we are hitting token boundaries, we could see some discrepancies. I see that in UAG we have a small window that shifts the target ids before re-encoding them.

Contributor: @zucchini-nlp Unlike UAG, USD has no discrepancies, as all tokens validated by the target are guaranteed to be present in the draft vocab.

        # Update or initialize assistant IDs
        if self._prev_assistant_ids is None:
            assistant_input_ids = assistant_new_ids
        else:
            tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
            # If the number of new tokens is greater than zero, truncate the previous assistant IDs
            if tokens_to_remove > 0:
                self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
            assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
        assistant_input_ids = assistant_input_ids.to(torch.int)

        return assistant_input_ids, len(assistant_new_ids[0])
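The review thread above deserves a concrete illustration. Incremental re-encoding can disagree with re-encoding the full text when a merge crosses the boundary between old and new text; the tokenizations below are invented to show the shape of the problem, not taken from a real tokenizer:

```python
# Hypothetical tokenizations, for illustration only.
retokenized_full = ["fine", "tun", "ing"]  # encoding "finetuning" from scratch
incremental = ["fine"] + ["tuning"]        # old draft ids + separately encoded new text
assert retokenized_full != incremental     # the boundary drift UAG's lookbehind window guards against

# USD sidesteps this: every candidate the target validates is the image of a
# real draft token under AssistantToTargetTranslator.get_target_ids, so it is
# guaranteed to exist in the draft vocabulary, as the reply above notes.
```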


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
5 changes: 3 additions & 2 deletions src/transformers/generation/logits_process.py
@@ -1860,14 +1860,15 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
    ```
    """

-    def __init__(self, suppress_tokens, device: str = "cpu"):
+    def __init__(self, suppress_tokens, device: str = "cpu", filter_value: float = -float("Inf")):
        self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
+        self.filter_value = filter_value

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
-        scores = torch.where(suppress_token_mask, -float("inf"), scores)
+        scores = torch.where(suppress_token_mask, self.filter_value, scores)
        return scores
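A minimal sketch of the extended processor after this change (toy scores over a four-token vocab): suppressed positions now receive the configurable `filter_value` instead of a hard-coded `-inf`:

```python
import torch
from transformers.generation.logits_process import SuppressTokensLogitsProcessor

scores = torch.tensor([[0.1, 0.5, -0.2, 0.9]])
input_ids = torch.tensor([[0]])  # unused by this processor, required by the interface

processor = SuppressTokensLogitsProcessor(suppress_tokens=[1, 3], device="cpu", filter_value=-float("inf"))
print(processor(input_ids, scores))  # tensor([[0.1000, -inf, -0.2000, -inf]])
```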
39 changes: 29 additions & 10 deletions src/transformers/generation/utils.py
@@ -56,6 +56,7 @@
    CandidateGenerator,
    EarlyExitCandidateGenerator,
    PromptLookupCandidateGenerator,
+    UniversalSpeculativeDecodingGenerator,
    _crop_past_key_values,
    _prepare_attention_mask,
    _prepare_token_type_ids,
@@ -845,16 +846,33 @@ def _get_candidate_generator(
                max_length=generation_config.max_length,
            )
        elif different_tokenizers:
-            candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
-                input_ids=input_ids,
-                assistant_model=assistant_model,
-                generation_config=generation_config,
-                model_kwargs=model_kwargs,
-                inputs_tensor=inputs_tensor,
-                logits_processor=logits_processor,
-                target_tokenizer=target_tokenizer,
-                assistant_tokenizer=assistant_tokenizer,
-            )
+            match generation_config.do_sample:
+                case True:
+                    candidate_generator = UniversalSpeculativeDecodingGenerator(
+                        input_ids=input_ids,
+                        assistant_model=assistant_model,
+                        generation_config=generation_config,
+                        model_kwargs=model_kwargs,
+                        inputs_tensor=inputs_tensor,
+                        logits_processor=logits_processor,
+                        target_tokenizer=target_tokenizer,
+                        assistant_tokenizer=assistant_tokenizer,
+                        # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab()
+                        target_vocab_size=self.config.vocab_size,
+                    )
+                case False:
+                    candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
+                        input_ids=input_ids,
+                        assistant_model=assistant_model,
+                        generation_config=generation_config,
+                        model_kwargs=model_kwargs,
+                        inputs_tensor=inputs_tensor,
+                        logits_processor=logits_processor,
+                        target_tokenizer=target_tokenizer,
+                        assistant_tokenizer=assistant_tokenizer,
+                    )
+                case _:
+                    raise ValueError(f"Invalid value for `do_sample`: {generation_config.do_sample}")
        else:
            candidate_generator = AssistedCandidateGenerator(
                input_ids=input_ids,
@@ -4262,6 +4280,7 @@ def _assisted_decoding(

        # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
+        candidate_input_ids = candidate_input_ids.to(self.device)

        candidate_input_ids = candidate_input_ids.to(self.device)
        if candidate_logits is not None:
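Finally, a hedged end-to-end sketch of how this dispatch is reached (checkpoints are placeholders): passing an `assistant_model` with a tokenizer different from the target's, together with `do_sample=True`, should route generation through `UniversalSpeculativeDecodingGenerator`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

target = AutoModelForCausalLM.from_pretrained("big-org/target-model")      # placeholder
assistant = AutoModelForCausalLM.from_pretrained("small-org/draft-model")  # placeholder
target_tok = AutoTokenizer.from_pretrained("big-org/target-model")
assistant_tok = AutoTokenizer.from_pretrained("small-org/draft-model")

inputs = target_tok("The capital of France is", return_tensors="pt")
out = target.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=target_tok,                # target tokenizer
    assistant_tokenizer=assistant_tok,   # different draft tokenizer -> USD path
    do_sample=True,
    max_new_tokens=20,
)
print(target_tok.batch_decode(out, skip_special_tokens=True)[0])
```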