From 6097a8d14dd0bbd9f23fa660babddce5afdee769 Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 8 Dec 2024 06:16:08 -0800 Subject: [PATCH 01/29] fix prepare + apply --- .../generation/candidate_generator.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 25148b3a6b837c..19999e1cc56661 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -569,6 +569,7 @@ def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokeni self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer self._assistant_to_target_input_ids: dict[int, int] = self._get_assistant_to_target_input_ids() + self._assistant_to_target_input_ids_mapping = self._get_assistant_to_target_input_ids_mapping() self.suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ @@ -587,6 +588,9 @@ def _get_assistant_to_target_input_ids(self) -> dict[int, int]: assistant_vocab[tok]: target_vocab[tok] for tok in set(target_vocab.keys()) & set(assistant_vocab.keys()) } + def _get_assistant_to_target_input_ids_mapping(self): + return torch.tensor([self._assistant_to_target_input_ids.get(x, x) for x in range(max(self._assistant_to_target_input_ids.keys()) + 1)]) + def _get_suppress_input_ids(self) -> list[int]: """ Get the input ids that are in the assistant vocab but not in the target vocab. @@ -603,13 +607,8 @@ def get_target_ids( Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. 
""" device = assistant_candidate_ids.device - target_candidate_ids = ( - assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :] - .cpu() - .apply_(lambda x: self._assistant_to_target_input_ids.get(x, x)) - .to(device) - ) - return torch.cat((target_input_ids, target_candidate_ids.unsqueeze(0)), dim=1) + transformed_slice = self._assistant_to_target_input_ids_mapping[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :].cpu()].to(device) + return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: """ @@ -623,11 +622,7 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ -1 ] - target_logits_supported_indices: torch.IntTensor = ( - assistant_logits_supported_indices.cpu() - .apply_(lambda x: self._assistant_to_target_input_ids[x]) - .to(device) - ) + target_logits_supported_indices = self._assistant_to_target_input_ids_mapping[assistant_logits_supported_indices].to(device) target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] return target_logits @@ -739,6 +734,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Use translator to convert tokens and logits candidate_ids = assistant_output.sequences + self._prev_assistant_ids = candidate_ids candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits) target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids) target_logits = self._atm_translator.get_target_logits(candidate_logits) @@ -751,9 +747,12 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to """ # Calculate new tokens since last call target_seq_len = target_input_ids.shape[-1] - new_token_count = target_seq_len - self._prev_target_seq_len + if self._prev_target_seq_len == 0: + new_token_count = target_seq_len + else: + new_token_count = 1 target_new_ids = target_input_ids[:, -new_token_count:] - self._prev_target_seq_len = target_seq_len + # Convert only the new tokens target_new_text = self.target_tokenizer.batch_decode( @@ -765,11 +764,13 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to # Update or initialize assistant IDs if self._prev_assistant_ids is None: - self._prev_assistant_ids = assistant_new_ids + new_assistant_ids = assistant_new_ids else: - self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) + new_assistant_ids = torch.cat([self._prev_assistant_ids[:,:-(target_seq_len-self._prev_target_seq_len)], assistant_new_ids], dim=-1) - return self._prev_assistant_ids + self._prev_target_seq_len = target_seq_len + + return new_assistant_ids class PromptLookupCandidateGenerator(CandidateGenerator): From 71562fcdfdee06c81cd3bf7f273d123cf7b12fdc Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 8 Dec 2024 06:41:57 -0800 Subject: [PATCH 02/29] move to cpu --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 19999e1cc56661..1fb01f52c19ddc 100644 --- a/src/transformers/generation/candidate_generator.py +++ 
b/src/transformers/generation/candidate_generator.py @@ -621,7 +621,7 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ -1 - ] + ].cpu() target_logits_supported_indices = self._assistant_to_target_input_ids_mapping[assistant_logits_supported_indices].to(device) target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] return target_logits From 3b4e9daed3aaef5201fcbb88bf54786e36021543 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 9 Dec 2024 01:37:48 -0800 Subject: [PATCH 03/29] simplity suppress_tokens --- src/transformers/generation/logits_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 5fcd35c921af86..e0f27c98a6bff5 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1860,8 +1860,8 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, suppress_tokens, device: str = "cpu"): - self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device) + def __init__(self, suppress_tokens): + self.suppress_tokens = suppress_tokens @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: From 1dcdae46252aba9307fd11da3f86992e8574c24e Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 9 Dec 2024 06:37:12 -0800 Subject: [PATCH 04/29] fix bugs and refacatoring --- .../generation/candidate_generator.py | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 1fb01f52c19ddc..0f9dc6ceea89a9 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -568,8 +568,7 @@ class AssistantToTargetTranslator: def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer - self._assistant_to_target_input_ids: dict[int, int] = self._get_assistant_to_target_input_ids() - self._assistant_to_target_input_ids_mapping = self._get_assistant_to_target_input_ids_mapping() + self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() self.suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ @@ -578,25 +577,21 @@ def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokeni ] ) - def _get_assistant_to_target_input_ids(self) -> dict[int, int]: - """ - Get a mapping from assistant tokens to target tokens based on vocabularies. 
- """ + def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() assistant_vocab = self._assistant_tokenizer.get_vocab() - return { - assistant_vocab[tok]: target_vocab[tok] for tok in set(target_vocab.keys()) & set(assistant_vocab.keys()) - } - - def _get_assistant_to_target_input_ids_mapping(self): - return torch.tensor([self._assistant_to_target_input_ids.get(x, x) for x in range(max(self._assistant_to_target_input_ids.keys()) + 1)]) - + max_assistant_index = max(assistant_vocab.values()) + assistant_to_target_input_ids = torch.full((max_assistant_index+1,), -1, dtype=int) # -1 means not in target vocab + for tok, idx in assistant_vocab.items(): + if tok in target_vocab: + assistant_to_target_input_ids[idx] = target_vocab[tok] + return assistant_to_target_input_ids + def _get_suppress_input_ids(self) -> list[int]: """ Get the input ids that are in the assistant vocab but not in the target vocab. """ - assistant_vocab = self._assistant_tokenizer.get_vocab() - return list(set(assistant_vocab.values()) - set(self._assistant_to_target_input_ids.keys())) + return torch.where(self._assistant_to_target_input_ids==-1)[0] def get_target_ids( self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor @@ -606,9 +601,15 @@ def get_target_ids( Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens. Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. """ - device = assistant_candidate_ids.device - transformed_slice = self._assistant_to_target_input_ids_mapping[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :].cpu()].to(device) - return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) + + i = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] + if i == 0: + return target_input_ids + else: + #assert len(assistant_candidate_ids[0]) > assistant_input_ids.shape[1] + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :]] + #assert torch.all(transformed_slice != -1) + return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: """ @@ -622,7 +623,7 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ -1 ].cpu() - target_logits_supported_indices = self._assistant_to_target_input_ids_mapping[assistant_logits_supported_indices].to(device) + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices].to(device) target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] return target_logits @@ -703,7 +704,7 @@ def __init__( logits_processor, ) # Track sequence lengths and previous assistant IDs - self._prev_target_seq_len: int = 0 + self._candidates_target_seq_len: int = 0 self._prev_assistant_ids: Optional[torch.LongTensor] = None def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: @@ -727,6 +728,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, generation_args["generation_config"].return_dict_in_generate = True # Generate and 
process outputs using translator + generation_args['logits_processor'] = self._atm_translator.logits_processors assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs) self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values @@ -735,8 +737,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Use translator to convert tokens and logits candidate_ids = assistant_output.sequences self._prev_assistant_ids = candidate_ids - candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits) target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids) + self._candidates_target_seq_len = target_ids.shape[-1] target_logits = self._atm_translator.get_target_logits(candidate_logits) return target_ids, target_logits @@ -747,12 +749,11 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to """ # Calculate new tokens since last call target_seq_len = target_input_ids.shape[-1] - if self._prev_target_seq_len == 0: + if self._candidates_target_seq_len == 0: new_token_count = target_seq_len else: new_token_count = 1 target_new_ids = target_input_ids[:, -new_token_count:] - # Convert only the new tokens target_new_text = self.target_tokenizer.batch_decode( @@ -764,13 +765,15 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to # Update or initialize assistant IDs if self._prev_assistant_ids is None: - new_assistant_ids = assistant_new_ids + assistant_input_ids = assistant_new_ids else: - new_assistant_ids = torch.cat([self._prev_assistant_ids[:,:-(target_seq_len-self._prev_target_seq_len)], assistant_new_ids], dim=-1) - - self._prev_target_seq_len = target_seq_len - - return new_assistant_ids + i = self._candidates_target_seq_len+1-target_seq_len + if i > 0: + self._prev_assistant_ids = self._prev_assistant_ids[:,:-i] + assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) + assistant_input_ids = assistant_input_ids.to(torch.int) + + return assistant_input_ids class PromptLookupCandidateGenerator(CandidateGenerator): From 10d1e56ff8b473863f9ce4532e0fe2f5e1f6b4f1 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 9 Dec 2024 06:43:12 -0800 Subject: [PATCH 05/29] device move --- src/transformers/generation/candidate_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 0f9dc6ceea89a9..72cf147af86266 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -606,8 +606,9 @@ def get_target_ids( if i == 0: return target_input_ids else: + device = assistant_candidate_ids.device #assert len(assistant_candidate_ids[0]) > assistant_input_ids.shape[1] - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :]] + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :].cpu()].to(device) #assert torch.all(transformed_slice != -1) return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) From f9a260fed64444229771b2537f46fb32e329ce20 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 10 Dec 2024 04:09:15 -0800 Subject: [PATCH 06/29] handle self.config.vocab_size 
> len(target_tokenizer.get_vocab()) --- .../generation/candidate_generator.py | 19 +++++++++++++------ src/transformers/generation/utils.py | 1 + 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 72cf147af86266..4e95cbb4261104 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -565,9 +565,13 @@ class AssistantToTargetTranslator: Translate the assistant into the target universe. """ - def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"): + def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer + target_tokenizer_vocab_size: int = len(target_tokenizer.get_vocab()) + # paddind is required in the case that target_vocab_size is bigger than set(target_tokenizer.get_vocab().keys()) + if target_tokenizer_vocab_size < target_vocab_size: + self._padding_size = target_vocab_size - target_tokenizer_vocab_size self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() self.suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( @@ -607,9 +611,7 @@ def get_target_ids( return target_input_ids else: device = assistant_candidate_ids.device - #assert len(assistant_candidate_ids[0]) > assistant_input_ids.shape[1] transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :].cpu()].to(device) - #assert torch.all(transformed_slice != -1) return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: @@ -626,6 +628,9 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT ].cpu() target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices].to(device) target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] + if hasattr(self, '_padding_size'): + padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), 0.0).to(device) + target_logits = torch.cat((target_logits, padding), dim=2) return target_logits @@ -640,7 +645,7 @@ class AssistantVocabTranslatorCache: @classmethod def get_translator( - cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase" + cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) @@ -650,7 +655,7 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: - mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer) + mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer, target_vocab_size) assistant_dict[assistant_tokenizer] = mapping return mapping @@ -689,11 +694,12 @@ def __init__( assistant_tokenizer: "PreTrainedTokenizerBase", generation_config: "GenerationConfig", model_kwargs: Dict, + target_vocab_size: 
int, inputs_tensor: Optional[torch.Tensor] = None, logits_processor: "LogitsProcessorList" = None, ): # Initialize translator before parent class - self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer) + self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer, target_vocab_size) super().__init__( input_ids, assistant_model, @@ -707,6 +713,7 @@ def __init__( # Track sequence lengths and previous assistant IDs self._candidates_target_seq_len: int = 0 self._prev_assistant_ids: Optional[torch.LongTensor] = None + self.target_vocab_size = target_vocab_size def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7a9d78168ac903..7f319d59d14ae0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -858,6 +858,7 @@ def _get_candidate_generator( logits_processor=logits_processor, target_tokenizer=target_tokenizer, assistant_tokenizer=assistant_tokenizer, + target_vocab_size=self.config.vocab_size # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() ) case False: candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( From 0d3310d34fdac8ec5edceb230a8be2bf29a3a1b4 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 10 Dec 2024 04:55:36 -0800 Subject: [PATCH 07/29] no need to normalize in candidate_generator --- src/transformers/generation/candidate_generator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 4e95cbb4261104..e1e3272f376c46 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -576,8 +576,7 @@ def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokeni self.suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ - SuppressTokensLogitsProcessor(self.suppress_input_ids), - LogitNormalization(), + SuppressTokensLogitsProcessor(self.suppress_input_ids) ] ) @@ -629,7 +628,7 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices].to(device) target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] if hasattr(self, '_padding_size'): - padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), 0.0).to(device) + padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to(device) target_logits = torch.cat((target_logits, padding), dim=2) return target_logits From 98cd50bffdac1479c15bb0eef253bd564100e994 Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 11 Dec 2024 03:12:49 -0800 Subject: [PATCH 08/29] address Nadav's comments + minor --- .../generation/candidate_generator.py | 23 +++++++++---------- tests/generation/test_candidate_generator.py | 2 +- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index e1e3272f376c46..b092eeede4014c 100644 --- 
a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -24,7 +24,6 @@ from ..cache_utils import DynamicCache from ..pytorch_utils import isin_mps_friendly from .logits_process import ( - LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor, @@ -573,10 +572,10 @@ def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokeni if target_tokenizer_vocab_size < target_vocab_size: self._padding_size = target_vocab_size - target_tokenizer_vocab_size self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() - self.suppress_input_ids: list[int] = self._get_suppress_input_ids() + self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ - SuppressTokensLogitsProcessor(self.suppress_input_ids) + SuppressTokensLogitsProcessor(self._suppress_input_ids) ] ) @@ -710,7 +709,7 @@ def __init__( logits_processor, ) # Track sequence lengths and previous assistant IDs - self._candidates_target_seq_len: int = 0 + self._target_seq_len_with_candidates: int = 0 self._prev_assistant_ids: Optional[torch.LongTensor] = None self.target_vocab_size = target_vocab_size @@ -742,10 +741,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, candidate_logits = torch.stack(assistant_output.scores, dim=1) # Use translator to convert tokens and logits - candidate_ids = assistant_output.sequences - self._prev_assistant_ids = candidate_ids - target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids) - self._candidates_target_seq_len = target_ids.shape[-1] + self._prev_assistant_ids = assistant_output.sequences + target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, self._prev_assistant_ids) + self._target_seq_len_with_candidates = target_ids.shape[-1] target_logits = self._atm_translator.get_target_logits(candidate_logits) return target_ids, target_logits @@ -756,7 +754,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to """ # Calculate new tokens since last call target_seq_len = target_input_ids.shape[-1] - if self._candidates_target_seq_len == 0: + if self._target_seq_len_with_candidates == 0: new_token_count = target_seq_len else: new_token_count = 1 @@ -774,9 +772,10 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to if self._prev_assistant_ids is None: assistant_input_ids = assistant_new_ids else: - i = self._candidates_target_seq_len+1-target_seq_len - if i > 0: - self._prev_assistant_ids = self._prev_assistant_ids[:,:-i] + tokens_to_remove = self._target_seq_len_with_candidates+1-target_seq_len + # If the number of new tokens is greater than zero, truncate the previous assistant IDs + if tokens_to_remove > 0: + self._prev_assistant_ids = self._prev_assistant_ids[:,:-tokens_to_remove] assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(torch.int) diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 260f92109b7bb2..dd7e427a3bfda9 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -79,7 +79,7 @@ def test_get_assistant_to_target_input_ids(self): def test_get_suppress_input_ids(self): """Test the suppression of assistant input IDs not 
present in the target vocabulary.""" expected_suppress_ids = [4] - actual_suppress_ids = self.translator.suppress_input_ids + actual_suppress_ids = self.translator._suppress_input_ids self.assertEqual(actual_suppress_ids, expected_suppress_ids) def test_get_target_ids(self): From 82606246ddb4e8d12a5c3c228c1051befc980278 Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 11 Dec 2024 05:15:16 -0800 Subject: [PATCH 09/29] optimize device move + SuppressTokensLogitsProcessor --- .../generation/candidate_generator.py | 31 ++++++++++--------- src/transformers/generation/logits_process.py | 11 +++---- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index b092eeede4014c..437241d131adc5 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -564,9 +564,11 @@ class AssistantToTargetTranslator: Translate the assistant into the target universe. """ - def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int): + def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int, + assistant_vocab_size:int): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer + self.assistant_model_device = assistant_model_device target_tokenizer_vocab_size: int = len(target_tokenizer.get_vocab()) # paddind is required in the case that target_vocab_size is bigger than set(target_tokenizer.get_vocab().keys()) if target_tokenizer_vocab_size < target_vocab_size: @@ -575,7 +577,7 @@ def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokeni self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ - SuppressTokensLogitsProcessor(self._suppress_input_ids) + SuppressTokensLogitsProcessor(self._suppress_input_ids, assistant_vocab_size, self.assistant_model_device) ] ) @@ -587,7 +589,7 @@ def _get_assistant_to_target_input_ids(self): for tok, idx in assistant_vocab.items(): if tok in target_vocab: assistant_to_target_input_ids[idx] = target_vocab[tok] - return assistant_to_target_input_ids + return assistant_to_target_input_ids.to(self.assistant_model_device) def _get_suppress_input_ids(self) -> list[int]: """ @@ -604,30 +606,28 @@ def get_target_ids( Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. """ - i = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] - if i == 0: + num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] + if num_new_tokens == 0: return target_input_ids else: - device = assistant_candidate_ids.device - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :].cpu()].to(device) + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens :]] return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: """ Return the target logits that correspond to the assistant logits. 
""" - device = assistant_logits.device target_vocab_size: int = len(self._target_tokenizer.get_vocab()) target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], target_vocab_size) - target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(device) + target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self.assistant_model_device) assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ -1 - ].cpu() - target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices].to(device) + ] + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices] target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] if hasattr(self, '_padding_size'): - padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to(device) + padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to(self.assistant_model_device) target_logits = torch.cat((target_logits, padding), dim=2) return target_logits @@ -643,7 +643,7 @@ class AssistantVocabTranslatorCache: @classmethod def get_translator( - cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int + cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int, assistant_vocab_size: int ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) @@ -653,7 +653,7 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: - mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer, target_vocab_size) + mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size, assistant_vocab_size) assistant_dict[assistant_tokenizer] = mapping return mapping @@ -697,7 +697,8 @@ def __init__( logits_processor: "LogitsProcessorList" = None, ): # Initialize translator before parent class - self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer, target_vocab_size) + self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer, assistant_model.device, + target_vocab_size, assistant_model.config.vocab_size) super().__init__( input_ids, assistant_model, diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index e0f27c98a6bff5..76bffdbd753a99 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1860,15 +1860,14 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, suppress_tokens): - self.suppress_tokens = suppress_tokens + def __init__(self, suppress_tokens, assistant_vocab_size, assistant_model_device): + suppress_tokens = suppress_tokens + vocab_tensor = torch.arange(assistant_vocab_size, device=assistant_model_device) + self.suppress_token_mask = isin_mps_friendly(vocab_tensor, suppress_tokens) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - 
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device)) - scores = torch.where(suppress_token_mask, -float("inf"), scores) - return scores + return scores.masked_fill_(self.suppress_token_mask, -float("inf")) class WhisperTimeStampLogitsProcessor(LogitsProcessor): From ff7977e7d0fca0436c04c8bb9b1398774a1770bf Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 12 Dec 2024 04:06:24 -0800 Subject: [PATCH 10/29] AssistantToTargetTranslator, SuppressTokensLogitsProcessor and tokenizers mapping improvements --- .../generation/candidate_generator.py | 90 +++++++++++++------ src/transformers/generation/logits_process.py | 15 ++-- src/transformers/generation/utils.py | 2 +- tests/generation/test_configuration_utils.py | 3 +- 4 files changed, 74 insertions(+), 36 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 437241d131adc5..608ba5a5d993ea 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -564,20 +564,27 @@ class AssistantToTargetTranslator: Translate the assistant into the target universe. """ - def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int, - assistant_vocab_size:int): + def __init__( + self, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + assistant_model_device, + target_vocab_size: int, + assistant_vocab_size: int, + ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer - self.assistant_model_device = assistant_model_device - target_tokenizer_vocab_size: int = len(target_tokenizer.get_vocab()) + self._assistant_model_device = assistant_model_device + target_tokenizer_vocab_size: int = len(target_tokenizer) # paddind is required in the case that target_vocab_size is bigger than set(target_tokenizer.get_vocab().keys()) if target_tokenizer_vocab_size < target_vocab_size: self._padding_size = target_vocab_size - target_tokenizer_vocab_size self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() - self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ - SuppressTokensLogitsProcessor(self._suppress_input_ids, assistant_vocab_size, self.assistant_model_device) + SuppressTokensLogitsProcessor( + self._get_mapped_input_ids(), assistant_vocab_size, self._assistant_model_device + ) ] ) @@ -585,17 +592,19 @@ def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() assistant_vocab = self._assistant_tokenizer.get_vocab() max_assistant_index = max(assistant_vocab.values()) - assistant_to_target_input_ids = torch.full((max_assistant_index+1,), -1, dtype=int) # -1 means not in target vocab + assistant_to_target_input_ids = torch.full( + (max_assistant_index + 1,), -1, dtype=int + ) # -1 means not in target vocab for tok, idx in assistant_vocab.items(): if tok in target_vocab: assistant_to_target_input_ids[idx] = target_vocab[tok] - return assistant_to_target_input_ids.to(self.assistant_model_device) - - def _get_suppress_input_ids(self) -> list[int]: + return assistant_to_target_input_ids.to(self._assistant_model_device) + + def 
_get_mapped_input_ids(self) -> list[int]: """ - Get the input ids that are in the assistant vocab but not in the target vocab. + Get the input ids that are both in the assistant vocab and in the target vocab. """ - return torch.where(self._assistant_to_target_input_ids==-1)[0] + return torch.where(self._assistant_to_target_input_ids != -1)[0] def get_target_ids( self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor @@ -605,30 +614,35 @@ def get_target_ids( Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens. Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. """ - + num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] if num_new_tokens == 0: return target_input_ids else: - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens :]] + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: """ Return the target logits that correspond to the assistant logits. """ - target_vocab_size: int = len(self._target_tokenizer.get_vocab()) + target_vocab_size: int = len(self._target_tokenizer) target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], target_vocab_size) - target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self.assistant_model_device) + target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device) assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ -1 ] target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices] target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] - if hasattr(self, '_padding_size'): - padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to(self.assistant_model_device) - target_logits = torch.cat((target_logits, padding), dim=2) + if hasattr(self, "_padding_size"): + padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to( + self._assistant_model_device + ) + if self._target_tokenizer.padding_side == "right": + target_logits = torch.cat((target_logits, padding), dim=2) + elif self._padding_side == "left": + target_logits = torch.cat((padding, target_logits), dim=2) return target_logits @@ -643,7 +657,12 @@ class AssistantVocabTranslatorCache: @classmethod def get_translator( - cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int, assistant_vocab_size: int + cls, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + assistant_model_device, + target_vocab_size: int, + assistant_vocab_size: int, ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) @@ -653,7 +672,13 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: - mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer, 
assistant_model_device, target_vocab_size, assistant_vocab_size) + mapping = AssistantToTargetTranslator( + target_tokenizer, + assistant_tokenizer, + assistant_model_device, + target_vocab_size, + assistant_vocab_size, + ) assistant_dict[assistant_tokenizer] = mapping return mapping @@ -697,8 +722,13 @@ def __init__( logits_processor: "LogitsProcessorList" = None, ): # Initialize translator before parent class - self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer, assistant_model.device, - target_vocab_size, assistant_model.config.vocab_size) + self._atm_translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, + assistant_tokenizer, + assistant_model.device, + target_vocab_size, + assistant_model.config.vocab_size, + ) super().__init__( input_ids, assistant_model, @@ -735,7 +765,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, generation_args["generation_config"].return_dict_in_generate = True # Generate and process outputs using translator - generation_args['logits_processor'] = self._atm_translator.logits_processors + generation_args["logits_processor"] = self._atm_translator.logits_processors assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs) self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values @@ -743,7 +773,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Use translator to convert tokens and logits self._prev_assistant_ids = assistant_output.sequences - target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, self._prev_assistant_ids) + target_ids = self._atm_translator.get_target_ids( + assistant_input_ids, target_input_ids, self._prev_assistant_ids + ) self._target_seq_len_with_candidates = target_ids.shape[-1] target_logits = self._atm_translator.get_target_logits(candidate_logits) @@ -773,13 +805,13 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to if self._prev_assistant_ids is None: assistant_input_ids = assistant_new_ids else: - tokens_to_remove = self._target_seq_len_with_candidates+1-target_seq_len + tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len # If the number of new tokens is greater than zero, truncate the previous assistant IDs if tokens_to_remove > 0: - self._prev_assistant_ids = self._prev_assistant_ids[:,:-tokens_to_remove] - assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) + self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] + assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(torch.int) - + return assistant_input_ids diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 76bffdbd753a99..5e8511ce7cd371 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1860,14 +1860,19 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, suppress_tokens, assistant_vocab_size, assistant_model_device): - suppress_tokens = suppress_tokens - vocab_tensor = torch.arange(assistant_vocab_size, device=assistant_model_device) - self.suppress_token_mask = isin_mps_friendly(vocab_tensor, suppress_tokens) + def __init__( + self, mapped_tokens, assistant_vocab_size, assistant_model_device, 
filter_value: float = -float("Inf") + ): + # Initialize a tensor of size assistant_vocab_size with True values + self.suppress_token_mask = torch.ones(assistant_vocab_size, dtype=torch.bool, device=assistant_model_device) + + # Set the values at indices specified in mapped_tokens to False + self.suppress_token_mask[mapped_tokens] = False + self.filter_value = filter_value @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - return scores.masked_fill_(self.suppress_token_mask, -float("inf")) + return scores.masked_fill_(self.suppress_token_mask, self.filter_value) class WhisperTimeStampLogitsProcessor(LogitsProcessor): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7f319d59d14ae0..c577f78966040b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -858,7 +858,7 @@ def _get_candidate_generator( logits_processor=logits_processor, target_tokenizer=target_tokenizer, assistant_tokenizer=assistant_tokenizer, - target_vocab_size=self.config.vocab_size # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() + target_vocab_size=self.config.vocab_size, # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() ) case False: candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index f4bd551bd7a4fc..970ab61eb63143 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -599,7 +599,8 @@ def test_serialize_generation_suppress_tokens(self): new_config = GenerationConfig.from_pretrained(tmp_dir) self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens) - suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens) + # TODO + suppress_processor = SuppressTokensLogitsProcessor(mapped_tokens=new_config.suppress_tokens) self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens) def test_serialize_generation_guidance_scale(self): From 38d81b1b3015763e08c13c53b36e9a06c5e528d1 Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 12 Dec 2024 04:09:45 -0800 Subject: [PATCH 11/29] padding size --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 608ba5a5d993ea..f9ee82d6aacb92 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -641,7 +641,7 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT ) if self._target_tokenizer.padding_side == "right": target_logits = torch.cat((target_logits, padding), dim=2) - elif self._padding_side == "left": + elif self._target_tokenizer.padding_side == "left": target_logits = torch.cat((padding, target_logits), dim=2) return target_logits From 6a7d3b3650614a5541a013e6e72f51c47e547f36 Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 12 Dec 2024 04:36:04 -0800 Subject: [PATCH 12/29] padding improvement --- src/transformers/generation/candidate_generator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py 
b/src/transformers/generation/candidate_generator.py index f9ee82d6aacb92..58c71c37b00f12 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -639,10 +639,11 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to( self._assistant_model_device ) - if self._target_tokenizer.padding_side == "right": - target_logits = torch.cat((target_logits, padding), dim=2) - elif self._target_tokenizer.padding_side == "left": - target_logits = torch.cat((padding, target_logits), dim=2) + padding_side_actions = { + "right": lambda: torch.cat((target_logits, padding), dim=2), + "left": lambda: torch.cat((padding, target_logits), dim=2) + } + target_logits = padding_side_actions.get(self._target_tokenizer.padding_side, lambda: target_logits)() return target_logits From e4e53b9ff14de17ede93680b7b3d450b95b64233 Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 12 Dec 2024 07:09:15 -0800 Subject: [PATCH 13/29] fix and simplify get_target_logits --- .../generation/candidate_generator.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 58c71c37b00f12..79179c79648de4 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -575,10 +575,7 @@ def __init__( self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer self._assistant_model_device = assistant_model_device - target_tokenizer_vocab_size: int = len(target_tokenizer) - # paddind is required in the case that target_vocab_size is bigger than set(target_tokenizer.get_vocab().keys()) - if target_tokenizer_vocab_size < target_vocab_size: - self._padding_size = target_vocab_size - target_tokenizer_vocab_size + self.target_vocab_size: int = target_vocab_size self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ @@ -626,24 +623,30 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT """ Return the target logits that correspond to the assistant logits. """ - target_vocab_size: int = len(self._target_tokenizer) - target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], target_vocab_size) + + target_shape: tuple[int, ...] 
= (*assistant_logits.shape[:-1], self.target_vocab_size) target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device) - assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") - assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ - -1 - ] - target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices] - target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] - if hasattr(self, "_padding_size"): - padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to( - self._assistant_model_device - ) - padding_side_actions = { - "right": lambda: torch.cat((target_logits, padding), dim=2), - "left": lambda: torch.cat((padding, target_logits), dim=2) - } - target_logits = padding_side_actions.get(self._target_tokenizer.padding_side, lambda: target_logits)() + assistant_indices = self._assistant_to_target_input_ids != -1 # Mask for valid indices + target_indices = self._assistant_to_target_input_ids[assistant_indices] # Exclude invalid indices + valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]] + + target_logits[..., target_indices] = valid_assistant_logits[..., assistant_indices] + + # assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") + # assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ + # -1 + # ] + # target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices] + # target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] + # if hasattr(self, "_padding_size"): + # padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to( + # self._assistant_model_device + # ) + # padding_side_actions = { + # "right": lambda: torch.cat((target_logits, padding), dim=2), + # "left": lambda: torch.cat((padding, target_logits), dim=2) + # } + # target_logits = padding_side_actions.get(self._target_tokenizer.padding_side, lambda: target_logits)() return target_logits From a19a9defd94142cfee36544878eb0569a28086c6 Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 12 Dec 2024 09:50:14 -0800 Subject: [PATCH 14/29] renaming in get_target_logits --- src/transformers/generation/candidate_generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 79179c79648de4..b4fbd08292bf57 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -626,11 +626,11 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT target_shape: tuple[int, ...] 
= (*assistant_logits.shape[:-1], self.target_vocab_size) target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device) - assistant_indices = self._assistant_to_target_input_ids != -1 # Mask for valid indices - target_indices = self._assistant_to_target_input_ids[assistant_indices] # Exclude invalid indices + assistant_indices_mask = self._assistant_to_target_input_ids != -1 # Mask for valid indices + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] # Exclude invalid indices valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]] - target_logits[..., target_indices] = valid_assistant_logits[..., assistant_indices] + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] # assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") # assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ From c4e4186c1dea890203a6e353fb82a54f819bcb10 Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 15 Dec 2024 03:35:54 -0800 Subject: [PATCH 15/29] minor --- .../generation/candidate_generator.py | 21 ++++--------------- src/transformers/generation/utils.py | 3 ++- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index b4fbd08292bf57..c15d4536d794aa 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -626,27 +626,14 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size) target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device) - assistant_indices_mask = self._assistant_to_target_input_ids != -1 # Mask for valid indices - target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] # Exclude invalid indices + # Mask for valid indices + assistant_indices_mask = self._assistant_to_target_input_ids != -1 + # Exclude invalid indices + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]] target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] - # assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf") - # assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[ - # -1 - # ] - # target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices] - # target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask] - # if hasattr(self, "_padding_size"): - # padding = torch.full((target_logits.size(0), target_logits.size(1), self._padding_size), -float("inf")).to( - # self._assistant_model_device - # ) - # padding_side_actions = { - # "right": lambda: torch.cat((target_logits, padding), dim=2), - # "left": lambda: torch.cat((padding, target_logits), dim=2) - # } - # target_logits = padding_side_actions.get(self._target_tokenizer.padding_side, lambda: target_logits)() return target_logits diff --git a/src/transformers/generation/utils.py 
b/src/transformers/generation/utils.py index c577f78966040b..e44650e0efd0db 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -858,7 +858,8 @@ def _get_candidate_generator( logits_processor=logits_processor, target_tokenizer=target_tokenizer, assistant_tokenizer=assistant_tokenizer, - target_vocab_size=self.config.vocab_size, # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() + # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() + target_vocab_size=self.config.vocab_size, ) case False: candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( From 0ec0788cfba49738f5bf30d8c5d5ec582981d171 Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 15 Dec 2024 04:48:25 -0800 Subject: [PATCH 16/29] add filter_value and suppress_tokens_id --- .../generation/candidate_generator.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index c15d4536d794aa..48d514dd642dbf 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -571,16 +571,20 @@ def __init__( assistant_model_device, target_vocab_size: int, assistant_vocab_size: int, + filter_value: float = -float("Inf"), + suppress_tokens_id: int = -1 ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer self._assistant_model_device = assistant_model_device self.target_vocab_size: int = target_vocab_size + self.filter_value: float = filter_value + self.suppress_tokens_id: int = suppress_tokens_id self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ SuppressTokensLogitsProcessor( - self._get_mapped_input_ids(), assistant_vocab_size, self._assistant_model_device + self._get_mapped_input_ids(), assistant_vocab_size, self._assistant_model_device, self.filter_value ) ] ) @@ -590,8 +594,8 @@ def _get_assistant_to_target_input_ids(self): assistant_vocab = self._assistant_tokenizer.get_vocab() max_assistant_index = max(assistant_vocab.values()) assistant_to_target_input_ids = torch.full( - (max_assistant_index + 1,), -1, dtype=int - ) # -1 means not in target vocab + (max_assistant_index + 1,), self.suppress_tokens_id, dtype=int + ) for tok, idx in assistant_vocab.items(): if tok in target_vocab: assistant_to_target_input_ids[idx] = target_vocab[tok] @@ -601,7 +605,7 @@ def _get_mapped_input_ids(self) -> list[int]: """ Get the input ids that are both in the assistant vocab and in the target vocab. """ - return torch.where(self._assistant_to_target_input_ids != -1)[0] + return torch.where(self._assistant_to_target_input_ids != self.suppress_tokens_id)[0] def get_target_ids( self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor @@ -625,9 +629,9 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT """ target_shape: tuple[int, ...] 
= (*assistant_logits.shape[:-1], self.target_vocab_size) - target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device) + target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device) # Mask for valid indices - assistant_indices_mask = self._assistant_to_target_input_ids != -1 + assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id # Exclude invalid indices target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]] From 200f7a0aba92a92c2acfd8fe5573e1290be48a04 Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 15 Dec 2024 06:09:01 -0800 Subject: [PATCH 17/29] style + rename --- .../generation/candidate_generator.py | 18 ++++++++---------- src/transformers/generation/utils.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 48d514dd642dbf..512ad33046e5c6 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -572,7 +572,7 @@ def __init__( target_vocab_size: int, assistant_vocab_size: int, filter_value: float = -float("Inf"), - suppress_tokens_id: int = -1 + suppress_tokens_id: int = -1, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer @@ -584,7 +584,7 @@ def __init__( self.logits_processors: LogitsProcessorList = LogitsProcessorList( [ SuppressTokensLogitsProcessor( - self._get_mapped_input_ids(), assistant_vocab_size, self._assistant_model_device, self.filter_value + self._get_mapped_input_ids(), assistant_vocab_size, self._assistant_model_device, self.filter_value ) ] ) @@ -593,9 +593,7 @@ def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() assistant_vocab = self._assistant_tokenizer.get_vocab() max_assistant_index = max(assistant_vocab.values()) - assistant_to_target_input_ids = torch.full( - (max_assistant_index + 1,), self.suppress_tokens_id, dtype=int - ) + assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int) for tok, idx in assistant_vocab.items(): if tok in target_vocab: assistant_to_target_input_ids[idx] = target_vocab[tok] @@ -631,10 +629,10 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT target_shape: tuple[int, ...] 
= (*assistant_logits.shape[:-1], self.target_vocab_size) target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device) # Mask for valid indices - assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id + assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id # Exclude invalid indices - target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] - valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]] + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] @@ -764,7 +762,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs) self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values - candidate_logits = torch.stack(assistant_output.scores, dim=1) + assistant_logits = torch.stack(assistant_output.scores, dim=1) # Use translator to convert tokens and logits self._prev_assistant_ids = assistant_output.sequences @@ -772,7 +770,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, assistant_input_ids, target_input_ids, self._prev_assistant_ids ) self._target_seq_len_with_candidates = target_ids.shape[-1] - target_logits = self._atm_translator.get_target_logits(candidate_logits) + target_logits = self._atm_translator.get_target_logits(assistant_logits) return target_ids, target_logits diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e44650e0efd0db..d7d9757d3e4f39 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -859,7 +859,7 @@ def _get_candidate_generator( target_tokenizer=target_tokenizer, assistant_tokenizer=assistant_tokenizer, # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() - target_vocab_size=self.config.vocab_size, + target_vocab_size=self.config.vocab_size, ) case False: candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( From 95bfa2cea40f2e473a770e13770a0a9243ba5962 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 04:36:45 -0800 Subject: [PATCH 18/29] remove TODO --- tests/generation/test_configuration_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 970ab61eb63143..f17a34a72569e2 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -599,7 +599,6 @@ def test_serialize_generation_suppress_tokens(self): new_config = GenerationConfig.from_pretrained(tmp_dir) self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens) - # TODO suppress_processor = SuppressTokensLogitsProcessor(mapped_tokens=new_config.suppress_tokens) self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens) From 1cbc871f6e62f8750fc4bd8c945bf5aad90ab623 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 04:38:33 -0800 Subject: [PATCH 19/29] restore original SelectTokensLogitsProcessor with modification --- 
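Note (illustration only, not part of the diff below): this patch splits the earlier mask-based processor out as SelectTokensLogitsProcessor (keep only the assistant ids that map into the target vocab) and restores SuppressTokensLogitsProcessor to its list-based form. A minimal standalone sketch of the two masking styles, with made-up sizes and token ids:

    import torch

    vocab_size, filter_value = 10, -float("inf")
    mapped_tokens = torch.tensor([0, 2, 3, 7])   # assistant ids that also exist in the target vocab
    scores = torch.randn(1, vocab_size)          # stand-in for assistant logits

    # "Select" style: keep the mapped ids, filter everything else.
    keep_mask = torch.zeros(vocab_size, dtype=torch.bool)
    keep_mask[mapped_tokens] = True
    selected = scores.masked_fill(~keep_mask, filter_value)

    # "Suppress" style: filter only the listed ids.
    suppress_tokens = torch.tensor([1, 4])
    vocab = torch.arange(vocab_size)
    suppressed = scores.masked_fill(torch.isin(vocab, suppress_tokens), filter_value)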
src/transformers/generation/logits_process.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 5e8511ce7cd371..6b0f3ca443e7b1 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1830,6 +1830,20 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed +class SelectTokensLogitsProcessor(LogitsProcessor): + def __init__( + self, mapped_tokens, assistant_vocab_size, assistant_model_device, filter_value: float = -float("Inf") + ): + # Initialize a tensor of size assistant_vocab_size with True values + self.suppress_token_mask = torch.ones(assistant_vocab_size, dtype=torch.bool, device=assistant_model_device) + + # Set the values at indices specified in mapped_tokens to False + self.suppress_token_mask[mapped_tokens] = False + self.filter_value = filter_value + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + return scores.masked_fill_(self.suppress_token_mask, self.filter_value) class SuppressTokensLogitsProcessor(LogitsProcessor): r""" @@ -1860,19 +1874,16 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): ``` """ - def __init__( - self, mapped_tokens, assistant_vocab_size, assistant_model_device, filter_value: float = -float("Inf") - ): - # Initialize a tensor of size assistant_vocab_size with True values - self.suppress_token_mask = torch.ones(assistant_vocab_size, dtype=torch.bool, device=assistant_model_device) - - # Set the values at indices specified in mapped_tokens to False - self.suppress_token_mask[mapped_tokens] = False + def __init__(self, suppress_tokens, device: str = "cpu", filter_value: float = -float("Inf")): + self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device) self.filter_value = filter_value @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - return scores.masked_fill_(self.suppress_token_mask, self.filter_value) + vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) + suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens) + scores = torch.where(suppress_token_mask, self.filter_value, scores) + return scores class WhisperTimeStampLogitsProcessor(LogitsProcessor): From 4a94849c5a76c3c15d332758bc892147f9bd90f7 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 04:42:09 -0800 Subject: [PATCH 20/29] fix style --- src/transformers/generation/logits_process.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 6b0f3ca443e7b1..c9378e15dfbf27 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1830,6 +1830,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed + class SelectTokensLogitsProcessor(LogitsProcessor): def __init__( self, mapped_tokens, assistant_vocab_size, assistant_model_device, filter_value: float = -float("Inf") @@ -1845,6 +1846,7 @@ def __init__( def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: return scores.masked_fill_(self.suppress_token_mask, self.filter_value) + class 
SuppressTokensLogitsProcessor(LogitsProcessor): r""" This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so From f1b6b08407a5f85bc902b93bb6bc4c8f173180d3 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 04:50:07 -0800 Subject: [PATCH 21/29] fix _update_past_and_masks and optimize code --- .../generation/candidate_generator.py | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 512ad33046e5c6..aebe8231b6163e 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -244,13 +244,15 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]: min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0) return min_new_tokens, max_new_tokens - def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool: + def _update_past_and_masks( + self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1 + ) -> bool: """Update past key values and attention masks for subsequent generation rounds.""" has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None if has_past_key_values: new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv self.assistant_kwargs["past_key_values"] = _crop_past_key_values( - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens ) self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder @@ -582,11 +584,7 @@ def __init__( self.suppress_tokens_id: int = suppress_tokens_id self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() self.logits_processors: LogitsProcessorList = LogitsProcessorList( - [ - SuppressTokensLogitsProcessor( - self._get_mapped_input_ids(), assistant_vocab_size, self._assistant_model_device, self.filter_value - ) - ] + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] ) def _get_assistant_to_target_input_ids(self): @@ -599,11 +597,11 @@ def _get_assistant_to_target_input_ids(self): assistant_to_target_input_ids[idx] = target_vocab[tok] return assistant_to_target_input_ids.to(self._assistant_model_device) - def _get_mapped_input_ids(self) -> list[int]: + def _get_suppress_input_ids(self) -> list[int]: """ - Get the input ids that are both in the assistant vocab and in the target vocab. + Get the input ids that are in the assistant vocab but not in the target vocab. """ - return torch.where(self._assistant_to_target_input_ids != self.suppress_tokens_id)[0] + return torch.where(self._assistant_to_target_input_ids == self.suppress_tokens_id)[0] def get_target_ids( self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor @@ -741,17 +739,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, """ Simplified version of get_candidates that uses the translator cache for token conversion. 
""" - input_ids = input_ids.to(self.assistant_model.device) - target_input_ids = input_ids.clone() - assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids) + target_input_ids = input_ids.to(self.assistant_model.device) + assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids) min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids) if max_new_tokens == 0: return input_ids, None - self._update_past_and_masks(assistant_input_ids) + self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) - self.assistant_kwargs.pop("attention_mask", None) # Ensure scores are returned generation_args["generation_config"].output_scores = True @@ -759,20 +755,16 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Generate and process outputs using translator generation_args["logits_processor"] = self._atm_translator.logits_processors - assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs) - self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values - - assistant_logits = torch.stack(assistant_output.scores, dim=1) + self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) # Use translator to convert tokens and logits - self._prev_assistant_ids = assistant_output.sequences - target_ids = self._atm_translator.get_target_ids( + target_candidate_ids = self._atm_translator.get_target_ids( assistant_input_ids, target_input_ids, self._prev_assistant_ids ) - self._target_seq_len_with_candidates = target_ids.shape[-1] - target_logits = self._atm_translator.get_target_logits(assistant_logits) + self._target_seq_len_with_candidates = target_candidate_ids.shape[-1] + target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits) - return target_ids, target_logits + return target_candidate_ids, target_candidate_logits def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor: """ @@ -805,7 +797,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(torch.int) - return assistant_input_ids + return assistant_input_ids, len(assistant_new_ids[0]) class PromptLookupCandidateGenerator(CandidateGenerator): From df68533d29440163c5cad0c69cbdaa398d1d745e Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 05:02:14 -0800 Subject: [PATCH 22/29] remove assistant_vocab_size arg --- src/transformers/generation/candidate_generator.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index aebe8231b6163e..04dd5c273e669a 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -572,7 +572,6 @@ def __init__( assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, target_vocab_size: int, - assistant_vocab_size: int, filter_value: float = -float("Inf"), suppress_tokens_id: int = -1, ): @@ -652,8 +651,7 @@ def get_translator( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, - target_vocab_size: int, - 
assistant_vocab_size: int, + target_vocab_size: int ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) @@ -667,8 +665,7 @@ def get_translator( target_tokenizer, assistant_tokenizer, assistant_model_device, - target_vocab_size, - assistant_vocab_size, + target_vocab_size ) assistant_dict[assistant_tokenizer] = mapping @@ -717,8 +714,7 @@ def __init__( target_tokenizer, assistant_tokenizer, assistant_model.device, - target_vocab_size, - assistant_model.config.vocab_size, + target_vocab_size ) super().__init__( input_ids, From 35e354a70251fbd55d4506d04e11867a6570f318 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 06:57:19 -0800 Subject: [PATCH 23/29] fix attention_mask --- src/transformers/generation/candidate_generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 04dd5c273e669a..25eae27a5c5c5c 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -751,6 +751,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Generate and process outputs using translator generation_args["logits_processor"] = self._atm_translator.logits_processors + self.assistant_kwargs.pop("attention_mask", None) self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) # Use translator to convert tokens and logits From a558bd0939a5d0f8b1aa231907a5fc856fd8b555 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 16 Dec 2024 07:12:14 -0800 Subject: [PATCH 24/29] call _prepare_attention_mask also if not has_past_key_values --- src/transformers/generation/candidate_generator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 25eae27a5c5c5c..178448ecd2163c 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -254,10 +254,10 @@ def _update_past_and_masks( self.assistant_kwargs["past_key_values"] = _crop_past_key_values( self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens ) - self.assistant_kwargs = _prepare_attention_mask( - self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder - ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1]) + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder + ) return has_past_key_values def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict: @@ -751,7 +751,6 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Generate and process outputs using translator generation_args["logits_processor"] = self._atm_translator.logits_processors - self.assistant_kwargs.pop("attention_mask", None) self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) # Use translator to convert tokens and logits From 5c3ad58211a356b95125723641b5441cef6982f7 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 17 Dec 2024 01:47:05 -0800 Subject: [PATCH 25/29] handling attention mask for first generation --- src/transformers/generation/candidate_generator.py | 11 ++++++++--- 1 
file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 178448ecd2163c..a1a9ba26f217ba 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -254,10 +254,11 @@ def _update_past_and_masks( self.assistant_kwargs["past_key_values"] = _crop_past_key_values( self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens ) + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder + ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1]) - self.assistant_kwargs = _prepare_attention_mask( - self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder - ) + return has_past_key_values def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict: @@ -744,6 +745,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) + if self._prev_assistant_ids is None: + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder + ) # Ensure scores are returned generation_args["generation_config"].output_scores = True From 811a4e536ecc11a9cc33678f24c465cee8e9b3c1 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 17 Dec 2024 02:31:34 -0800 Subject: [PATCH 26/29] comment --- .../generation/candidate_generator.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index a1a9ba26f217ba..26b2ecf3063e1a 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -258,7 +258,7 @@ def _update_past_and_masks( self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1]) - + return has_past_key_values def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict: @@ -652,7 +652,7 @@ def get_translator( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device, - target_vocab_size: int + target_vocab_size: int, ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) @@ -663,10 +663,7 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: mapping = AssistantToTargetTranslator( - target_tokenizer, - assistant_tokenizer, - assistant_model_device, - target_vocab_size + target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size ) assistant_dict[assistant_tokenizer] = mapping @@ -712,10 +709,7 @@ def __init__( ): # Initialize translator before parent class self._atm_translator = AssistantVocabTranslatorCache.get_translator( - target_tokenizer, - assistant_tokenizer, - assistant_model.device, - target_vocab_size + target_tokenizer, assistant_tokenizer, assistant_model.device, 
target_vocab_size ) super().__init__( input_ids, @@ -743,12 +737,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, if max_new_tokens == 0: return input_ids, None - self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) - generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) if self._prev_assistant_ids is None: + # Prepare attention mask for the first generation + # For subsequent generations, the attention mask is updated in _update_past_and_masks self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder - ) + ) + + self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) + generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) # Ensure scores are returned generation_args["generation_config"].output_scores = True From 2dcc9edf9423cad9125e8bb8d47271e78946412c Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 17 Dec 2024 02:37:59 -0800 Subject: [PATCH 27/29] restore test --- tests/generation/test_configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index f17a34a72569e2..f4bd551bd7a4fc 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -599,7 +599,7 @@ def test_serialize_generation_suppress_tokens(self): new_config = GenerationConfig.from_pretrained(tmp_dir) self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens) - suppress_processor = SuppressTokensLogitsProcessor(mapped_tokens=new_config.suppress_tokens) + suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens) self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens) def test_serialize_generation_guidance_scale(self): From f2be0da20e9f3473d3236261a8cf0f47ed6e0587 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 17 Dec 2024 03:07:07 -0800 Subject: [PATCH 28/29] remove SelectTokensLogitsProcessor --- .../generation/candidate_generator.py | 4 ++-- src/transformers/generation/logits_process.py | 17 ----------------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 26b2ecf3063e1a..072fdf74e92577 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -738,8 +738,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, return input_ids, None if self._prev_assistant_ids is None: - # Prepare attention mask for the first generation - # For subsequent generations, the attention mask is updated in _update_past_and_masks + # Prepare attention mask for the first generation. + # For subsequent generations, the attention mask is updated in _update_past_and_masks. 
self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder ) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index c9378e15dfbf27..98effe8263e9a9 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1830,23 +1830,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed - -class SelectTokensLogitsProcessor(LogitsProcessor): - def __init__( - self, mapped_tokens, assistant_vocab_size, assistant_model_device, filter_value: float = -float("Inf") - ): - # Initialize a tensor of size assistant_vocab_size with True values - self.suppress_token_mask = torch.ones(assistant_vocab_size, dtype=torch.bool, device=assistant_model_device) - - # Set the values at indices specified in mapped_tokens to False - self.suppress_token_mask[mapped_tokens] = False - self.filter_value = filter_value - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - return scores.masked_fill_(self.suppress_token_mask, self.filter_value) - - class SuppressTokensLogitsProcessor(LogitsProcessor): r""" This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so From 83b825085020cd0b513e1473302c682963400665 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 17 Dec 2024 07:44:24 -0800 Subject: [PATCH 29/29] _update_past_and_masks implementation for USD --- .../generation/candidate_generator.py | 16 +++++++++------- src/transformers/generation/logits_process.py | 1 + 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 072fdf74e92577..a37481018086d0 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -737,13 +737,6 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, if max_new_tokens == 0: return input_ids, None - if self._prev_assistant_ids is None: - # Prepare attention mask for the first generation. - # For subsequent generations, the attention mask is updated in _update_past_and_masks. - self.assistant_kwargs = _prepare_attention_mask( - self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder - ) - self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) @@ -764,6 +757,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, return target_candidate_ids, target_candidate_logits + def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool: + if self._prev_assistant_ids is None: + # Prepare attention mask for the first generation. + # For subsequent generations, the attention mask is updated in super()_update_past_and_masks. 
+ self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder + ) + return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) + def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor: """ Simplified token conversion that only processes new tokens. diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 98effe8263e9a9..e818b266cd7b7c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1830,6 +1830,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed + class SuppressTokensLogitsProcessor(LogitsProcessor): r""" This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
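Note (illustration only, not part of the patch series): the common thread in these patches is a fixed lookup tensor from assistant token ids to target token ids, with a sentinel value for assistant tokens that have no target counterpart. A rough standalone sketch of that idea, using toy vocabularies and made-up names (build_lookup, SUPPRESS_ID):

    import torch

    SUPPRESS_ID = -1  # sentinel: assistant token has no target-vocab counterpart

    def build_lookup(target_vocab: dict, assistant_vocab: dict) -> torch.Tensor:
        # One slot per assistant id; token strings shared by both vocabs keep their target id.
        table = torch.full((max(assistant_vocab.values()) + 1,), SUPPRESS_ID, dtype=torch.long)
        for tok, idx in assistant_vocab.items():
            if tok in target_vocab:
                table[idx] = target_vocab[tok]
        return table

    # Toy vocabularies, not real tokenizers.
    target_vocab = {"a": 0, "b": 1, "c": 2}
    assistant_vocab = {"a": 0, "c": 1, "z": 2}
    lookup = build_lookup(target_vocab, assistant_vocab)   # tensor([ 0,  2, -1])

    # Translating candidate ids is then a single tensor index, in the spirit of get_target_ids.
    assistant_candidates = torch.tensor([[0, 1]])
    target_candidates = lookup[assistant_candidates]       # tensor([[0, 2]])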