[OLD] New PR: #35029. Universal Speculative Decoding CandidateGenerator #34760

Closed
176 changes: 87 additions & 89 deletions src/transformers/generation/candidate_generator.py
@@ -194,45 +194,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
vocabulary_size)` containing the logits associated to each candidate.
"""
input_ids = input_ids.to(self.assistant_model.device)

# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
# Calculate new tokens to generate
min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
if max_new_tokens == 0:
return input_ids, None

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
# Update past key values and masks
self._update_past_and_masks(input_ids)
# Generate candidates
generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
candidate_ids, candidate_logits = self._generate_candidates(generation_args)
return candidate_ids, candidate_logits
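
For orientation, a minimal usage sketch of the path that reaches this get_candidates: passing assistant_model to generate routes generation through AssistedCandidateGenerator. The checkpoints below are illustrative placeholders, not part of this PR.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")  # target model
assistant = AutoModelForCausalLM.from_pretrained("gpt2")    # smaller draft model, same tokenizer

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# `assistant_model` enables assisted (speculative) decoding; each round calls
# get_candidates() above to draft tokens that the target model then verifies.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))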

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
@@ -261,6 +231,45 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
else:
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)

def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
"""Calculate the minimum and maximum number of new tokens to generate."""
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
return min_new_tokens, max_new_tokens
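
To make the token-budget arithmetic concrete, a short worked example with assumed values:

# Assumed: num_assistant_tokens = 5, generation_config.max_length = 20, main_model_min_length = 0.
# Current length 10: max_new_tokens = min(5, 20 - 10 - 1) = 5; min_new_tokens = max(min(5, 0 - 10), 0) = 0.
# Current length 18: max_new_tokens = min(5, 20 - 18 - 1) = 1, i.e. the budget shrinks so that
# the target model can still append its one extra token without exceeding max_length.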

def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
"""Update past key values and attention masks for subsequent generation rounds."""
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
return has_past_key_values
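
The crop size involves two separate off-by-ones; a comment-style walk-through under assumed values:

# Assumed: the accepted sequence is L tokens long and remove_from_pkv = 0.
# Then new_cache_size = L - 1, and the cache is cropped to new_cache_size - 1 = L - 2:
# one position is dropped because the final token is re-fed as input, and one more
# because the assistant never computed a cache entry for the token after the last match.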

def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
"""Prepare arguments for the generation call."""
return {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""Generate candidate sequences using the assistant model."""
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
return candidate_ids, candidate_logits
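
The stacking step is easiest to read with concrete shapes; a self-contained sketch under assumed sizes:

import torch

# assistant_output.scores is a tuple holding one (batch_size, vocab_size) tensor per
# drafted token; stacking on dim=1 yields the (batch_size, candidate_length, vocab_size)
# logits that the target model consumes during verification.
batch_size, candidate_length, vocab_size = 1, 3, 32000  # assumed sizes
scores = tuple(torch.randn(batch_size, vocab_size) for _ in range(candidate_length))
candidate_logits = torch.stack(scores, dim=1)
assert candidate_logits.shape == (batch_size, candidate_length, vocab_size)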


class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
"""
@@ -310,6 +319,8 @@ def __init__(

self.target_tokenizer = target_tokenizer
self.assistant_tokenizer = assistant_tokenizer
self.prev_target_ids = None
self.prev_tokens = None
self.prev_assistant_ids = None
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
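
Both windows are read from the assistant's generation config; a hedged configuration sketch (the values shown are assumed defaults, not prescribed by this diff):

# Tunable overlap windows on the assistant's GenerationConfig; 10 is an assumed default.
assistant_model.generation_config.target_lookbehind = 10     # target-side window
assistant_model.generation_config.assistant_lookbehind = 10  # assistant-side window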
@@ -440,27 +451,50 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
return input_ids, None

input_ids = input_ids.to(self.assistant_model.device)
remove_from_pkv = 0

assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
self.prev_assistant_ids = assistant_input_ids

min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)

self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)

# Update state
self.prev_target_ids = input_ids
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.prev_tokens = assistant_output.sequences

if input_ids.shape[1] >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None

def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
"""Converts target input IDs to assistant input IDs, handling discrepancies."""
convert_kwargs = {
"source_tokenizer": self.target_tokenizer,
"destination_tokenizer": self.assistant_tokenizer,
}
remove_from_pkv = 0

# Since re-encoding the tokens may result in tokenization discrepancies, we use two lookbehind values
# (one for each conversion) that mark where to start looking for the overlap between the
# source and target encodings, ensuring the new tokens include the correct prompt suffix.
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
# input_ids contains all target prompt input ids and some new target input ids
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind

new_assistant_ids = self.convert_source_tokens_to_target_tokens(
input_ids[:, start_index_in_target_window:], **convert_kwargs
)
prompt_use_length = new_assistant_ids.shape[1]
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

discrepancy_length, new_tokens_only, discrepancy_only = (
AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
prompt_use, new_assistant_ids
)
assistant_input_ids = self.prev_assistant_ids

@@ -481,58 +515,29 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
else:
# edge case: no overlap between the prompt and new_assistant_ids
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)

else:
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
self.prev_target_ids = input_ids

self.prev_assistant_ids = assistant_input_ids
new_cur_len = assistant_input_ids.shape[-1]
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: assistant_input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
return assistant_input_ids, remove_from_pkv
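
The conversion helper used above is, at heart, decode-then-re-encode; a standalone sketch of the idea (tokenizer checkpoints are placeholders, and the function mirrors the real method's behavior without claiming its exact signature):

from transformers import AutoTokenizer

def convert_source_tokens_to_target_tokens(input_ids, source_tokenizer, destination_tokenizer):
    # Decode with the source tokenizer, then re-encode with the destination one.
    # The two vocabularies can disagree near the boundary, which is exactly the
    # discrepancy the lookbehind/overlap logic above compensates for.
    text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    return destination_tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids

target_tok = AutoTokenizer.from_pretrained("gpt2")                  # placeholder
assistant_tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder
ids = target_tok("speculative decoding", return_tensors="pt").input_ids
print(convert_source_tokens_to_target_tokens(ids, target_tok, assistant_tok))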

def _process_assistant_outputs(
self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
) -> torch.LongTensor:
"""Processes assistant outputs to obtain target input IDs."""
num_prev_assistant = self.prev_assistant_ids.shape[1]
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
if start_assistant_look_index < 0:
start_assistant_look_index = 0

new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
assistant_output.sequences[:, start_assistant_look_index:],
assistant_sequences[:, start_assistant_look_index:],
source_tokenizer=self.assistant_tokenizer,
destination_tokenizer=self.target_tokenizer,
)
target_prompt_use_length = new_target_ids_from_window.shape[1]

target_prompt_use = input_ids[:, -target_prompt_use_length:]

_, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
target_prompt_use, new_target_ids_from_window
)
_, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)

new_target_ids = input_ids

@@ -546,14 +551,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
if hasattr(self.generation_config, "max_length"):
new_target_ids = new_target_ids[:, : self.generation_config.max_length]

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

# 4. Prepare variables for output
if input_ids.shape[1] >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None
return new_target_ids
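
With both helpers in place, the different-tokenizer path is exercised end to end as below; a hedged sketch with placeholder checkpoints (any causal LM pair with mismatched tokenizers would do):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")  # placeholder target
assistant_tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder draft
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Speculative decoding lets a small model", return_tensors="pt")
# Supplying both tokenizers selects AssistedCandidateGeneratorDifferentTokenizers.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))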


class PromptLookupCandidateGenerator(CandidateGenerator):
43 changes: 43 additions & 0 deletions tests/generation/test_candidate_generator.py
@@ -0,0 +1,43 @@
import unittest

import numpy as np

from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))

def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))

def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))

def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
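
These tests are self-contained (NumPy arrays only, no model weights), so they can be run in isolation:

python -m pytest tests/generation/test_candidate_generator.py -v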
39 changes: 0 additions & 39 deletions tests/generation/test_utils.py
@@ -92,7 +92,6 @@
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
from transformers.generation.utils import _speculative_sampling


@@ -4274,41 +4273,3 @@ def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
# bos_token_id is required when no input ids nor inputs_embeds is passed
with self.assertRaises(ValueError):
model.generate(max_length=20, bos_token_id=None)


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))

def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))

def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))

def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))