diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 25148b3a6b837c..f1ab4b5992efe1 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -717,13 +717,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, """ input_ids = input_ids.to(self.assistant_model.device) target_input_ids = input_ids.clone() - assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids) + assistant_input_ids, remove_from_kv = self._prepare_assistant_input_ids( + target_input_ids + ) min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids) if max_new_tokens == 0: return input_ids, None - self._update_past_and_masks(assistant_input_ids) + self._update_past_and_masks(assistant_input_ids, remove_from_kv) generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) self.assistant_kwargs.pop("attention_mask", None) @@ -745,12 +747,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, return target_ids, target_logits - def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor: + def _prepare_assistant_input_ids( + self, target_input_ids: torch.LongTensor + ) -> Tuple[torch.LongTensor, int]: """ Simplified token conversion that only processes new tokens. """ # Calculate new tokens since last call target_seq_len = target_input_ids.shape[-1] + remove_from_pkv = target_seq_len - 1 - self._prev_target_seq_len new_token_count = target_seq_len - self._prev_target_seq_len target_new_ids = target_input_ids[:, -new_token_count:] self._prev_target_seq_len = target_seq_len @@ -769,7 +774,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to else: self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) - return self._prev_assistant_ids + return self._prev_assistant_ids, remove_from_pkv class PromptLookupCandidateGenerator(CandidateGenerator):