Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prepare + apply #7

Merged
merged 29 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6097a8d
fix prepare + apply
jmamou Dec 8, 2024
71562fc
move to cpu
jmamou Dec 8, 2024
3b4e9da
simplity suppress_tokens
jmamou Dec 9, 2024
1dcdae4
fix bugs and refacatoring
jmamou Dec 9, 2024
10d1e56
device move
jmamou Dec 9, 2024
f9a260f
handle self.config.vocab_size > len(target_tokenizer.get_vocab())
jmamou Dec 10, 2024
0d3310d
no need to normalize in candidate_generator
jmamou Dec 10, 2024
98cd50b
address Nadav's comments + minor
jmamou Dec 11, 2024
8260624
optimize device move + SuppressTokensLogitsProcessor
jmamou Dec 11, 2024
ff7977e
AssistantToTargetTranslator, SuppressTokensLogitsProcessor and tokeni…
jmamou Dec 12, 2024
38d81b1
padding size
jmamou Dec 12, 2024
6a7d3b3
padding improvement
jmamou Dec 12, 2024
e4e53b9
fix and simplify get_target_logits
jmamou Dec 12, 2024
a19a9de
renaming in get_target_logits
jmamou Dec 12, 2024
c4e4186
minor
jmamou Dec 15, 2024
0ec0788
add filter_value and suppress_tokens_id
jmamou Dec 15, 2024
200f7a0
style + rename
jmamou Dec 15, 2024
95bfa2c
remove TODO
jmamou Dec 16, 2024
1cbc871
restore original SelectTokensLogitsProcessor with modification
jmamou Dec 16, 2024
4a94849
fix style
jmamou Dec 16, 2024
f1b6b08
fix _update_past_and_masks and optimize code
jmamou Dec 16, 2024
df68533
remove assistant_vocab_size arg
jmamou Dec 16, 2024
35e354a
fix attention_mask
jmamou Dec 16, 2024
a558bd0
call _prepare_attention_mask also if not has_past_key_values
jmamou Dec 16, 2024
5c3ad58
handling attention mask for first generation
jmamou Dec 17, 2024
811a4e5
comment
jmamou Dec 17, 2024
2dcc9ed
restore test
jmamou Dec 17, 2024
f2be0da
remove SelectTokensLogitsProcessor
jmamou Dec 17, 2024
83b8250
_update_past_and_masks implementation for USD
jmamou Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 55 additions & 44 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import (
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
SuppressTokensLogitsProcessor,
Expand Down Expand Up @@ -565,34 +564,38 @@ class AssistantToTargetTranslator:
Translate the assistant into the target universe.
"""

def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"):
def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int,
assistant_vocab_size:int):
self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
self._assistant_to_target_input_ids: dict[int, int] = self._get_assistant_to_target_input_ids()
self.suppress_input_ids: list[int] = self._get_suppress_input_ids()
self.assistant_model_device = assistant_model_device
jmamou marked this conversation as resolved.
Show resolved Hide resolved
target_tokenizer_vocab_size: int = len(target_tokenizer.get_vocab())
# paddind is required in the case that target_vocab_size is bigger than set(target_tokenizer.get_vocab().keys())
if target_tokenizer_vocab_size < target_vocab_size:
jmamou marked this conversation as resolved.
Show resolved Hide resolved
self._padding_size = target_vocab_size - target_tokenizer_vocab_size
self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
self.logits_processors: LogitsProcessorList = LogitsProcessorList(
[
SuppressTokensLogitsProcessor(self.suppress_input_ids),
LogitNormalization(),
jmamou marked this conversation as resolved.
Show resolved Hide resolved
SuppressTokensLogitsProcessor(self._suppress_input_ids, assistant_vocab_size, self.assistant_model_device)
]
)

def _get_assistant_to_target_input_ids(self) -> dict[int, int]:
"""
Get a mapping from assistant tokens to target tokens based on vocabularies.
"""
def _get_assistant_to_target_input_ids(self):
target_vocab = self._target_tokenizer.get_vocab()
assistant_vocab = self._assistant_tokenizer.get_vocab()
return {
assistant_vocab[tok]: target_vocab[tok] for tok in set(target_vocab.keys()) & set(assistant_vocab.keys())
}

max_assistant_index = max(assistant_vocab.values())
assistant_to_target_input_ids = torch.full((max_assistant_index+1,), -1, dtype=int) # -1 means not in target vocab
for tok, idx in assistant_vocab.items():
if tok in target_vocab:
assistant_to_target_input_ids[idx] = target_vocab[tok]
return assistant_to_target_input_ids.to(self.assistant_model_device)

def _get_suppress_input_ids(self) -> list[int]:
"""
Get the input ids that are in the assistant vocab but not in the target vocab.
"""
assistant_vocab = self._assistant_tokenizer.get_vocab()
return list(set(assistant_vocab.values()) - set(self._assistant_to_target_input_ids.keys()))
return torch.where(self._assistant_to_target_input_ids==-1)[0]

def get_target_ids(
self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
Expand All @@ -602,33 +605,30 @@ def get_target_ids(
Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
"""
device = assistant_candidate_ids.device
target_candidate_ids = (
assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :]
.cpu()
.apply_(lambda x: self._assistant_to_target_input_ids.get(x, x))
.to(device)
)
return torch.cat((target_input_ids, target_candidate_ids.unsqueeze(0)), dim=1)

num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
if num_new_tokens == 0:
return target_input_ids
else:
transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens :]]
return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)

def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
"""
Return the target logits that correspond to the assistant logits.
"""
device = assistant_logits.device
target_vocab_size: int = len(self._target_tokenizer.get_vocab())
target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], target_vocab_size)
target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(device)
target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self.assistant_model_device)
assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf")
assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[
-1
]
target_logits_supported_indices: torch.IntTensor = (
assistant_logits_supported_indices.cpu()
.apply_(lambda x: self._assistant_to_target_input_ids[x])
.to(device)
)
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices]
target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask]
if hasattr(self, '_padding_size'):
padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to(self.assistant_model_device)
target_logits = torch.cat((target_logits, padding), dim=2)
return target_logits


Expand All @@ -643,7 +643,7 @@ class AssistantVocabTranslatorCache:

@classmethod
def get_translator(
cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"
cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int, assistant_vocab_size: int
) -> AssistantToTargetTranslator:
with cls._lock:
assistant_dict = cls._cache.get(target_tokenizer)
Expand All @@ -653,7 +653,7 @@ def get_translator(

mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer)
mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size, assistant_vocab_size)
assistant_dict[assistant_tokenizer] = mapping

return mapping
Expand Down Expand Up @@ -692,11 +692,13 @@ def __init__(
assistant_tokenizer: "PreTrainedTokenizerBase",
generation_config: "GenerationConfig",
model_kwargs: Dict,
target_vocab_size: int,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
):
# Initialize translator before parent class
self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer)
self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer, assistant_model.device,
target_vocab_size, assistant_model.config.vocab_size)
super().__init__(
input_ids,
assistant_model,
Expand All @@ -708,8 +710,9 @@ def __init__(
logits_processor,
)
# Track sequence lengths and previous assistant IDs
self._prev_target_seq_len: int = 0
self._target_seq_len_with_candidates: int = 0
self._prev_assistant_ids: Optional[torch.LongTensor] = None
self.target_vocab_size = target_vocab_size

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Expand All @@ -732,15 +735,16 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
generation_args["generation_config"].return_dict_in_generate = True

# Generate and process outputs using translator
generation_args['logits_processor'] = self._atm_translator.logits_processors
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

candidate_logits = torch.stack(assistant_output.scores, dim=1)

# Use translator to convert tokens and logits
candidate_ids = assistant_output.sequences
candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits)
target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids)
self._prev_assistant_ids = assistant_output.sequences
target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, self._prev_assistant_ids)
self._target_seq_len_with_candidates = target_ids.shape[-1]
target_logits = self._atm_translator.get_target_logits(candidate_logits)

return target_ids, target_logits
Expand All @@ -751,9 +755,11 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
"""
# Calculate new tokens since last call
target_seq_len = target_input_ids.shape[-1]
new_token_count = target_seq_len - self._prev_target_seq_len
if self._target_seq_len_with_candidates == 0:
new_token_count = target_seq_len
else:
new_token_count = 1
target_new_ids = target_input_ids[:, -new_token_count:]
self._prev_target_seq_len = target_seq_len

# Convert only the new tokens
target_new_text = self.target_tokenizer.batch_decode(
Expand All @@ -765,11 +771,16 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to

# Update or initialize assistant IDs
if self._prev_assistant_ids is None:
self._prev_assistant_ids = assistant_new_ids
assistant_input_ids = assistant_new_ids
else:
self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)

return self._prev_assistant_ids
tokens_to_remove = self._target_seq_len_with_candidates+1-target_seq_len
# If the number of new tokens is greater than zero, truncate the previous assistant IDs
if tokens_to_remove > 0:
self._prev_assistant_ids = self._prev_assistant_ids[:,:-tokens_to_remove]
assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
assistant_input_ids = assistant_input_ids.to(torch.int)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the documentation, cat operates on arrays of the same type. Wdyt about ensuring that self._prev_assistant_ids and assistant_new_ids are already of torch.int type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean adding before cat

self._prev_assistant_ids = self._prev_assistant_ids.to(torch.int)
assistant_new_ids = assistant_new_ids.to(torch.int)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wdyt about ensuring we only assign torch.int to self._prev_assistant_ids and assistant_new_ids in the first place—so that we never need to cast them into torch.int?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we get all the IDs from the tokenizer and their type is int. Do you think that it is necessary to ensure they are of int type?


return assistant_input_ids


class PromptLookupCandidateGenerator(CandidateGenerator):
Expand Down
11 changes: 5 additions & 6 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,15 +1860,14 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, suppress_tokens, device: str = "cpu"):
self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
def __init__(self, suppress_tokens, assistant_vocab_size, assistant_model_device):
suppress_tokens = suppress_tokens
jmamou marked this conversation as resolved.
Show resolved Hide resolved
vocab_tensor = torch.arange(assistant_vocab_size, device=assistant_model_device)
self.suppress_token_mask = isin_mps_friendly(vocab_tensor, suppress_tokens)

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores
return scores.masked_fill_(self.suppress_token_mask, -float("inf"))


class WhisperTimeStampLogitsProcessor(LogitsProcessor):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ def _get_candidate_generator(
logits_processor=logits_processor,
target_tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
target_vocab_size=self.config.vocab_size # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab()
jmamou marked this conversation as resolved.
Show resolved Hide resolved
)
case False:
candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
Expand Down
2 changes: 1 addition & 1 deletion tests/generation/test_candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_get_assistant_to_target_input_ids(self):
def test_get_suppress_input_ids(self):
"""Test the suppression of assistant input IDs not present in the target vocabulary."""
expected_suppress_ids = [4]
actual_suppress_ids = self.translator.suppress_input_ids
actual_suppress_ids = self.translator._suppress_input_ids
self.assertEqual(actual_suppress_ids, expected_suppress_ids)

def test_get_target_ids(self):
Expand Down