Skip to content

Commit

Permalink
[Speculative decoding] [Multi-Step] decouple should_modify_greedy_pro…
Browse files Browse the repository at this point in the history
…bs_inplace (vllm-project#6971)
  • Loading branch information
SolitaryThinker authored Aug 9, 2024
1 parent 99b4cf5 commit 57b7be0
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 3 deletions.
27 changes: 26 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import random
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
import torch
Expand Down Expand Up @@ -703,3 +703,28 @@ def test_sampling_params(sampling_params: List[SamplingParams]):

assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0]


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
set_random_seed(42)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampler.include_gpu_probs_tensor = True
sampler.should_modify_greedy_probs_inplace = False

sampling_params = SamplingParams(temperature=0)

mock_inplace = Mock()
with patch(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
mock_inplace):

sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
mock_inplace.assert_not_called()

assert sampler_output.sampled_token_probs is not None
assert sampler_output.logprobs is not None
assert sampler_output.sampled_token_ids is not None
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,10 @@ def org_vocab_size(self):
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor

@property
def should_modify_greedy_probs_inplace(self):
return self.base_layer.should_modify_greedy_probs_inplace

def create_lora_weights(
self,
max_loras: int,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False

def _init_sampling_tensors(
self,
Expand Down Expand Up @@ -177,8 +178,7 @@ def _should_modify_greedy_probs_inplace(self) -> bool:
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return self.include_gpu_probs_tensor
return self.should_modify_greedy_probs_inplace


def _get_bin_counts_and_mask(
Expand Down
3 changes: 3 additions & 0 deletions vllm/spec_decode/medusa_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def init_device(self):
def set_include_gpu_probs_tensor(self):
pass

def set_should_modify_greedy_probs_inplace(self):
pass

@torch.inference_mode()
def sampler_output(
self,
Expand Down
4 changes: 4 additions & 0 deletions vllm/spec_decode/multi_step_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.model.sampler.include_gpu_probs_tensor = True

def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)

@torch.inference_mode()
def sampler_output(
self,
Expand Down
4 changes: 4 additions & 0 deletions vllm/spec_decode/proposer_worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def set_include_gpu_probs_tensor(self) -> None:
"""Implementation optional"""
pass

def set_should_modify_greedy_probs_inplace(self) -> None:
"""Implementation optional"""
pass


class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
"""Proposer worker which does not use a model with kvcache"""
Expand Down
6 changes: 6 additions & 0 deletions vllm/spec_decode/smaller_tp_proposer_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for multi_step_worker
self._worker.set_include_gpu_probs_tensor()

def set_should_modify_greedy_probs_inplace(self) -> None:
if self._is_dummy:
return

self._worker.set_should_modify_greedy_probs_inplace()

def load_model(self) -> None:
if self._is_dummy:
return
Expand Down
3 changes: 3 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ def _configure_model_sampler_for_spec_decode(self):
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
Expand Down

0 comments on commit 57b7be0

Please sign in to comment.