[Speculative Decoding] Move indices to device before filtering output (#10850)

Co-authored-by: Yang Zheng(SW)(Alex) <[email protected]>
zhengy001 and Yang Zheng(SW)(Alex) authored Dec 3, 2024
1 parent 9323a31 commit f6084f6
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions vllm/spec_decode/multi_step_worker.py
@@ -120,6 +120,9 @@ def sampler_output(
                     indices_of_seq_with_bonus_tokens)
                 model_outputs.append(model_output)
 
+        # move indices to device to avoid stream sync
+        indices_of_seq_with_bonus_tokens = torch.tensor(
+            indices_of_seq_with_bonus_tokens, device=self.device)
         filtered_model_outputs = self._filter_model_output(
             model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
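
Reviewer note: the hunk above builds the index tensor once, on the worker's device, before any filtering happens. A minimal sketch of the idea, assuming PyTorch is installed. Indexing a device tensor with a plain Python list can force PyTorch to build the index tensor on the host and copy it over at each indexing call, while an index tensor that already lives on the device can be reused. The names `probs`, `keep`, and `keep_on_device` are illustrative and do not come from the commit, and the CPU fallback is only there so the sketch runs without a GPU.

```python
import torch

# Assumption: fall back to CPU so the sketch runs anywhere.
# The commit itself uses self.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical per-sequence values for an expanded batch of 4 sequences.
probs = torch.arange(12, dtype=torch.float32, device=device).reshape(4, 3)

# Indices collected on the host as a plain Python list.
keep = [0, 2]

# Convert once, on the target device, then reuse for every tensor to filter.
keep_on_device = torch.tensor(keep, device=device)

# Row selection with an index tensor that already lives on the device.
filtered = probs[keep_on_device]
filtered_list = filtered.cpu().tolist()
```

Both `probs[keep]` and `probs[keep_on_device]` select the same rows. The difference is only in where and how often the index data is transferred.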
@@ -189,7 +192,7 @@ def _expand_execute_model_request(
     @staticmethod
     def _filter_model_output(
             expanded_batch_outputs: List[SamplerOutput],
-            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
+            output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
         """
         Filters the model output to include only the specified sequence
         outputs. This method contracts the expanded batch output from the
@@ -199,8 +202,8 @@ def _filter_model_output(
         Args:
             expanded_batch_output (List[SamplerOutput]): The expanded output
                 batch from the model.
-            output_indices_to_retain (List[int]): Indices of the model outputs
-                to retain.
+            output_indices_to_retain (torch.Tensor): Indices of the model
+                outputs to retain.
         Returns:
             List[SamplerOutput]: A list containing the filtered model
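
Reviewer note: the docstring describes a filter that keeps only the selected sequence entries in every step's output. The following is a rough sketch of that contract with a `torch.Tensor` of indices, assuming PyTorch is installed. `StepOutput` and `filter_step_outputs` are hypothetical stand-ins written for illustration; they are not the vLLM `SamplerOutput` class or the body of `_filter_model_output`, which this diff does not show.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class StepOutput:
    """Hypothetical stand-in for one step's sampler output (not vLLM's class)."""
    sampled_token_ids: torch.Tensor
    logprobs: Optional[torch.Tensor] = None


def filter_step_outputs(expanded: List[StepOutput],
                        indices_to_retain: torch.Tensor) -> List[StepOutput]:
    """Keep only the rows at indices_to_retain in every step output."""

    def pick(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Optional fields stay None; present tensors are indexed row-wise.
        return None if t is None else t[indices_to_retain]

    return [
        StepOutput(sampled_token_ids=pick(step.sampled_token_ids),
                   logprobs=pick(step.logprobs))
        for step in expanded
    ]


# Two decoding steps over an expanded batch of three sequences.
steps = [
    StepOutput(sampled_token_ids=torch.tensor([[10], [11], [12]])),
    StepOutput(sampled_token_ids=torch.tensor([[20], [21], [22]]),
               logprobs=torch.zeros(3, 5)),
]
result = filter_step_outputs(steps, torch.tensor([0, 2]))
```

Because the index tensor is passed in ready-made, the same object is reused for every field of every step, which is the usage pattern the signature change in this commit enables.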