Skip to content

Commit

Permalink
renaming in get_target_logits
Browse files Browse the repository at this point in the history
  • Loading branch information
jmamou committed Dec 12, 2024
1 parent e4e53b9 commit a19a9de
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,11 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT

target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(self._assistant_model_device)
assistant_indices = self._assistant_to_target_input_ids != -1 # Mask for valid indices
target_indices = self._assistant_to_target_input_ids[assistant_indices] # Exclude invalid indices
assistant_indices_mask = self._assistant_to_target_input_ids != -1 # Mask for valid indices
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] # Exclude invalid indices
valid_assistant_logits = assistant_logits[..., :self._assistant_to_target_input_ids.shape[0]]

target_logits[..., target_indices] = valid_assistant_logits[..., assistant_indices]
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]

# assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf")
# assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[
Expand Down

0 comments on commit a19a9de

Please sign in to comment.