introduce self.prev_target_ids_len
keyboardAnt committed Dec 5, 2024
1 parent 1826a1c commit fcd129f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/transformers/generation/candidate_generator.py
@@ -319,7 +319,7 @@ def __init__(
 
         self.target_tokenizer = target_tokenizer
         self.assistant_tokenizer = assistant_tokenizer
-        self.prev_target_ids = None
+        self.prev_target_ids_len: Optional[int] = None
         self.prev_assistant_ids = None
         self.target_lookbehind = assistant_model.generation_config.target_lookbehind
         self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
@@ -465,11 +465,11 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
 
         # Update state
-        self.prev_target_ids = input_ids
+        self.prev_target_ids_len = input_ids.shape[1]
         self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
         self.prev_assistant_ids = assistant_output.sequences
 
-        if input_ids.shape[1] >= new_target_ids.shape[1]:
+        if self.prev_target_ids_len >= new_target_ids.shape[1]:
             return input_ids, None
 
         return new_target_ids, None
@@ -482,9 +482,9 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[tor
         }
         remove_from_pkv = 0
 
-        if self.prev_assistant_ids is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
+        if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
             # input_ids contains all target prompt input ids and some new target input ids
-            start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
+            start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind
 
             new_assistant_ids = self.convert_source_tokens_to_target_tokens(
                 input_ids[:, start_index_in_target_window:], **convert_kwargs
@@ -516,7 +516,7 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[tor
                 assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
         else:
             assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
-            self.prev_target_ids = input_ids
+            self.prev_target_ids_len = input_ids.shape[1]
 
         return assistant_input_ids, remove_from_pkv

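In short, the commit stops caching the full prev_target_ids tensor and keeps only its sequence length, since every remaining use in the diff reads shape[1]. Below is a minimal sketch of that pattern in isolation; StateTracker and its method names are illustrative, not part of candidate_generator.py.

# Illustrative sketch only (hypothetical StateTracker, not the actual
# AssistedCandidateGenerator classes): cache the previous target length
# as an int instead of keeping the whole tensor alive.
from typing import Optional

import torch


class StateTracker:
    def __init__(self):
        # Before this commit the equivalent state was a full LongTensor
        # (self.prev_target_ids = None); now only its length is stored.
        self.prev_target_ids_len: Optional[int] = None

    def update(self, input_ids: torch.LongTensor) -> None:
        # Only the length is ever consulted later, so an int is enough.
        self.prev_target_ids_len = input_ids.shape[1]

    def produced_new_tokens(self, new_target_ids: torch.LongTensor) -> bool:
        # Mirrors the diff's `self.prev_target_ids_len >= new_target_ids.shape[1]` check.
        return self.prev_target_ids_len is not None and new_target_ids.shape[1] > self.prev_target_ids_len


tracker = StateTracker()
tracker.update(torch.zeros((1, 5), dtype=torch.long))
print(tracker.produced_new_tokens(torch.zeros((1, 8), dtype=torch.long)))  # True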
