Fix prepare + apply #7

Merged
merged 29 commits into from
Dec 17, 2024
Changes from 1 commit
Commits
6097a8d
fix prepare + apply
jmamou Dec 8, 2024
71562fc
move to cpu
jmamou Dec 8, 2024
3b4e9da
simplify suppress_tokens
jmamou Dec 9, 2024
1dcdae4
fix bugs and refactoring
jmamou Dec 9, 2024
10d1e56
device move
jmamou Dec 9, 2024
f9a260f
handle self.config.vocab_size > len(target_tokenizer.get_vocab())
jmamou Dec 10, 2024
0d3310d
no need to normalize in candidate_generator
jmamou Dec 10, 2024
98cd50b
address Nadav's comments + minor
jmamou Dec 11, 2024
8260624
optimize device move + SuppressTokensLogitsProcessor
jmamou Dec 11, 2024
ff7977e
AssistantToTargetTranslator, SuppressTokensLogitsProcessor and tokeni…
jmamou Dec 12, 2024
38d81b1
padding size
jmamou Dec 12, 2024
6a7d3b3
padding improvement
jmamou Dec 12, 2024
e4e53b9
fix and simplify get_target_logits
jmamou Dec 12, 2024
a19a9de
renaming in get_target_logits
jmamou Dec 12, 2024
c4e4186
minor
jmamou Dec 15, 2024
0ec0788
add filter_value and suppress_tokens_id
jmamou Dec 15, 2024
200f7a0
style + rename
jmamou Dec 15, 2024
95bfa2c
remove TODO
jmamou Dec 16, 2024
1cbc871
restore original SelectTokensLogitsProcessor with modification
jmamou Dec 16, 2024
4a94849
fix style
jmamou Dec 16, 2024
f1b6b08
fix _update_past_and_masks and optimize code
jmamou Dec 16, 2024
df68533
remove assistant_vocab_size arg
jmamou Dec 16, 2024
35e354a
fix attention_mask
jmamou Dec 16, 2024
a558bd0
call _prepare_attention_mask also if not has_past_key_values
jmamou Dec 16, 2024
5c3ad58
handling attention mask for first generation
jmamou Dec 17, 2024
811a4e5
comment
jmamou Dec 17, 2024
2dcc9ed
restore test
jmamou Dec 17, 2024
f2be0da
remove SelectTokensLogitsProcessor
jmamou Dec 17, 2024
83b8250
_update_past_and_masks implementation for USD
jmamou Dec 17, 2024
6 changes: 3 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -626,11 +626,11 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT

         target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
         target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device)
-        assistant_indices = self._assistant_to_target_input_ids != -1  # Mask for valid indices
-        target_indices = self._assistant_to_target_input_ids[assistant_indices]  # Exclude invalid indices
+        assistant_indices_mask = self._assistant_to_target_input_ids != -1  # Mask for valid indices
+        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]  # Exclude invalid indices
jmamou marked this conversation as resolved.
         valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]]

-        target_logits[..., target_indices] = valid_assistant_logits[..., assistant_indices]
+        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]

         # assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf")
         # assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[
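For readers following the hunk above, here is a minimal standalone sketch of the remapping that `get_target_logits` appears to perform. It is not the code from the PR: the free function `map_assistant_logits_to_target` and its parameters are hypothetical stand-ins for the method and its `self._assistant_to_target_input_ids` / `self.target_vocab_size` attributes, and the example assumes the lookup tensor uses `-1` for assistant tokens with no target counterpart, as the diff's comments suggest.

```python
import torch


def map_assistant_logits_to_target(
    assistant_logits: torch.Tensor,
    assistant_to_target_input_ids: torch.Tensor,
    target_vocab_size: int,
) -> torch.Tensor:
    """Hypothetical helper: scatter assistant-vocab logits into target-vocab slots."""
    # Same leading dimensions as the input, but the last dimension is the target vocab.
    target_shape = (*assistant_logits.shape[:-1], target_vocab_size)
    # Target tokens with no assistant counterpart keep -inf (probability zero after softmax).
    target_logits = torch.full(
        target_shape,
        -float("inf"),
        dtype=assistant_logits.dtype,
        device=assistant_logits.device,
    )
    # Boolean mask over assistant token ids that have a target-side id.
    assistant_indices_mask = assistant_to_target_input_ids != -1
    # Target-side ids for exactly those assistant tokens.
    target_logits_supported_indices = assistant_to_target_input_ids[assistant_indices_mask]
    # Drop trailing logits beyond the lookup table (e.g. padded vocab entries).
    valid_assistant_logits = assistant_logits[..., : assistant_to_target_input_ids.shape[0]]
    target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
    return target_logits


# Toy example: 4 mapped assistant tokens (token 1 has no target id), 5 logits (one padded slot).
assistant_to_target = torch.tensor([2, -1, 0, 5])
assistant_logits = torch.arange(5.0).reshape(1, 1, 5)
target_logits = map_assistant_logits_to_target(assistant_logits, assistant_to_target, target_vocab_size=6)
```

In the toy example the assistant logits for tokens 0, 2 and 3 land at target positions 2, 0 and 5; every other target position stays at `-inf`. The renamed variables in the hunk make that split explicit: one name for the boolean mask on the assistant side, another for the integer indices on the target side.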