diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py new file mode 100644 index 0000000000000..9d3ef3c67d3dc --- /dev/null +++ b/tests/samplers/test_rejection_sampler.py @@ -0,0 +1,392 @@ +"""Tests for rejection sampling.""" +import pytest +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +from vllm.model_executor.utils import set_random_seed + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler + + +def mock_causal_accepted_tensor( + k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor: + """Generate an "accepted" tensor which should yield causally-accepted tokens + up to last accepted indices. + + Tokens after last_accepted_indices+1 may also be accepted, although they + will not be causally accepted. + """ + batch_size = last_accepted_indices.shape[0] + + accepted = (torch.arange(k).expand(batch_size, k) <= + last_accepted_indices.unsqueeze(-1).broadcast_to( + batch_size, k)).to(device="cuda") + + # Sprinkle accepted values after the contiguous initial accepted values. + # This replicates the behavior of rejection sampling, which may "accept" + # a token that cannot be accepted because of causality. + sprinkle_candidates = ( + torch.arange(k).expand(batch_size, k) > + last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1) + sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5 + accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates] + return accepted + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize( + "which_tokens_accepted", + ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@torch.inference_mode() +def test_correct_output_format(which_tokens_accepted: str, seed: int): + """Verify the output has correct format given predetermined accepted matrix. + """ + set_random_seed(seed) + + batch_size = 10 + k = 5 + vocab_size = 3000 + + if which_tokens_accepted == "all_tokens_accepted": + accepted = mock_causal_accepted_tensor( + k, -1 + k * torch.ones((batch_size, ), dtype=torch.long)) + elif which_tokens_accepted == "no_tokens_accepted": + accepted = mock_causal_accepted_tensor( + k, -torch.ones((batch_size, ), dtype=torch.long)) + elif which_tokens_accepted == "some_tokens_accepted": + last_accepted_indices = torch.randint(low=-1, + high=k, + size=(batch_size, )) + accepted = mock_causal_accepted_tensor(k, last_accepted_indices) + else: + raise AssertionError() + + recovered_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + + rejection_sampler = RejectionSampler() + rejection_sampler.init_gpu_tensors(rank=0) + output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + + if which_tokens_accepted == "all_tokens_accepted": + # Expect all tokens to be equal to draft tokens. + assert torch.equal(output_token_ids[:, :-1], draft_token_ids) + + # Expect all bonus tokens to be included. + assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) + elif which_tokens_accepted == "no_tokens_accepted": + # Expect first token to be equal to recovered tokens. 
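+        # For example (illustrative, with k=3): a row where the very first
+        # draft token is rejected should look like [recovered_0, -1, -1, -1];
+        # the recovered token fills the first position and every later
+        # position, including the bonus slot, is -1.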
+ assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) + + # Expect everything else to be -1. + assert torch.equal(output_token_ids[:, 1:], + torch.ones_like(output_token_ids[:, 1:]) * -1) + elif which_tokens_accepted == "some_tokens_accepted": + recovered_plus_bonus = torch.cat( + (recovered_token_ids, bonus_token_ids), dim=-1) + # Assert first rejected token is a recovered token or bonus token. + assert torch.equal( + recovered_plus_bonus[torch.arange(0, batch_size), + last_accepted_indices + 1], + output_token_ids[torch.arange(0, batch_size), + last_accepted_indices + 1]) + + # Assert every subsequent token is -1. + subsequent_mask = torch.arange(0, k + 1).expand( + batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1) + assert torch.all(output_token_ids[subsequent_mask] == -1) + + +@pytest.mark.parametrize("k", list(range(1, 6))) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", list(range(1, 32))) +@torch.inference_mode() +def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int): + rejection_sampler = RejectionSampler() + rejection_sampler.init_gpu_tensors(rank=0) + + draft_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + target_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids) + + +@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) +@pytest.mark.parametrize("which_token_ids", + ["bonus_token_ids", "draft_token_ids"]) +@torch.inference_mode() +def test_raises_when_vocab_oob(above_or_below_vocab_range: str, + which_token_ids: str): + k = 3 + batch_size = 5 + vocab_size = 30_000 + + rejection_sampler = RejectionSampler(strict_mode=True) + rejection_sampler.init_gpu_tensors(rank=0) + + draft_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + target_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + + oob_token_ids = None + if which_token_ids == "bonus_token_ids": + oob_token_ids = bonus_token_ids + elif which_token_ids == "draft_token_ids": + oob_token_ids = draft_token_ids + else: + raise AssertionError() + + if above_or_below_vocab_range == "above": + rogue_token_id = vocab_size + 1 + elif above_or_below_vocab_range == "below": + rogue_token_id = -1 + else: + raise AssertionError() + + oob_token_ids[0][0] = rogue_token_id + + with pytest.raises(AssertionError): + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids) + + +@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) +@pytest.mark.parametrize("seed", list(range(5))) +@torch.inference_mode() +def test_rejection_sampling_approximates_target_distribution( + seed: int, draft_and_target_probs_equal: bool): + """Verify rejection sampling approximates target distribution, + despite sampling from a potentially distinct draft distribution. 
+ + This is done by first creating a random target probability + distribution and a random draft probability distribution. We then + sample token ids from the rejection sampler using these draft + and target distributions. The samples are used to estimate + the output probability distribution, which we expect to approximate + the target distribution. + + A basic distance metric is used to determine similarity between + distributions. + + We expect that as we increase the number of samples, + the distance between the observed distribution and the target + distribution decreases. To measure this, we compare the distance + of the observed distribution against both the target distribution + and a uniform random distribution. We expect the distance between + the observed distribution and the target distribution to improve + much more than the distance improvement between the observed + distribution and the random distribution. + + When draft_and_target_probs_equal=True, the draft and target + probabilities are exactly equal. Rejection sampling should + still work without any NaNs or exceptions. + """ + set_random_seed(seed) + + helper = _CorrectnessTestHelper( + vocab_size=10, + rejection_sampler=RejectionSampler(), + ) + + draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( + draft_and_target_probs_equal) + + sample_sizes = [10, 100, 1_000, 10_000, 100_000] + distance_wrt_reference = [] + distance_wrt_target = [] + + for num_samples in sample_sizes: + (reference_vs_rejsample_dist, + target_vs_rejsample_dist) = helper.run_and_compare_distributions( + draft_probs, + target_probs, + reference_probs, + num_samples, + ) + + distance_wrt_reference.append(reference_vs_rejsample_dist) + distance_wrt_target.append(target_vs_rejsample_dist) + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}") + print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}") + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + expected_improvement_multiplier = 20 + assert (relative_change_in_distance_wrt_target > + relative_change_in_distance_wrt_reference * + expected_improvement_multiplier) + + +def get_ratio_first_to_last(elements: List[float]) -> float: + return elements[0] / elements[-1] + + +class _CorrectnessTestHelper: + """Class that packages together logic required for the unit-level + rejection sampling correctness test. + """ + + def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler): + self.rejection_sampler = rejection_sampler + self.vocab_size = vocab_size + self.vocab_range = (0, vocab_size) + + self.rejection_sampler.init_gpu_tensors(rank=0) + + # Keep test simple, use k=1 + self.k = 1 + + # Bonus tokens not used, but rejection sampler requires + # correct shape. 
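+        # (The sampler expects bonus_token_ids of shape [batch_size, 1];
+        # see RejectionSampler.forward.)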
+ self.num_bonus_tokens = 1 + + def generate_probs_for_test( + self, draft_and_target_probs_equal: bool + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + draft_probs, target_probs = [ + F.softmax( + torch.rand(self.vocab_size, dtype=torch.float32), + dim=-1, + ) for _ in range(2) + ] + + num_reference_probs = 100 + reference_probs = F.softmax( + torch.rand(num_reference_probs, + self.vocab_size, + dtype=torch.float32), + dim=-1, + ) + + if draft_and_target_probs_equal: + target_probs = draft_probs.clone() + + return draft_probs, target_probs, reference_probs + + def run_and_compare_distributions(self, draft_probs: torch.Tensor, + target_probs: torch.Tensor, + reference_probs: torch.Tensor, + num_samples: int) -> Tuple[float, float]: + # Sample using rejection sampling. + rej_sample_probs = self._estimate_rejection_sampling_pdf( + draft_probs, target_probs, num_samples) + + # Average distance from reference probs. + reference_vs_rejsample_dist = torch.dist( + reference_probs, + rej_sample_probs).item() / reference_probs.shape[0] + target_vs_rejsample_dist = torch.dist(target_probs, + rej_sample_probs).item() + + return reference_vs_rejsample_dist, target_vs_rejsample_dist + + def _estimate_rejection_sampling_pdf( + self, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + num_samples: int, + ) -> torch.Tensor: + # Repeat draft probs num_samples times. + draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( + num_samples, 1, 1) + + # Repeat target probs num_samples * k times. + # Rejection sampler requires bonus token probs, but they aren't used. + target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( + num_samples, self.k, 1) + + # Randomly sample draft token ids from draft probs. + draft_token_ids = torch.multinomial(draft_probs[:, 0, :], + num_samples=1, + replacement=True).reshape( + num_samples, self.k) + + # Bonus tokens not used but required. + bonus_token_ids = torch.zeros((1, self.num_bonus_tokens), + dtype=torch.int64, + device="cuda").repeat(num_samples, 1) + + # Get output tokens via rejection sampling. + output_token_ids = self.rejection_sampler(target_probs.to("cuda"), + bonus_token_ids.to("cuda"), + draft_probs.to("cuda"), + draft_token_ids.to("cuda")) + + # Remove bonus tokens + output_token_ids = output_token_ids[:, :-1].flatten() + + # Estimate probability density function + hist = torch.histogram(output_token_ids.to(dtype=torch.float, + device="cpu"), + bins=self.vocab_size, + range=self.vocab_range, + density=True) + + return hist.hist diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 0000000000000..3e1cfc783b8ef --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,392 @@ +from typing import Tuple, Optional +from functools import cached_property + +import torch +import torch.nn as nn +import torch.jit + + +class RejectionSampler(nn.Module): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, strict_mode: bool = False): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. 
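+
+        Example (a minimal sketch; tensor shapes below are illustrative,
+        see forward() for the exact shape requirements)::
+
+            sampler = RejectionSampler()
+            sampler.init_gpu_tensors(rank=0)
+            # target_probs, draft_probs: [batch_size, k, vocab_size]
+            # bonus_token_ids: [batch_size, 1]; draft_token_ids: [batch_size, k]
+            output_token_ids = sampler(target_probs, bonus_token_ids,
+                                       draft_probs, draft_token_ids)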
+ """ + super().__init__() + self.probs_dtype = torch.float32 + self.token_id_dtype = torch.int64 + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_shape(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_inconsistent_device(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + bonus_token_ids, + draft_token_ids) + + accepted, recovered_token_ids = self._batch_modified_rejection_sampling( + target_probs, + draft_probs, + draft_token_ids, + ) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. 
+ shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. + shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + recovered_token_ids = _multinomial(recovered_probs, + num_samples=1).reshape( + batch_size, k) + return accepted, recovered_token_ids + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of + :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according + to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + same conditional probability according to the draft model, the token + is accepted with probability: + + .. math:: + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. + """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indicies = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indicies, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indicies, + draft_token_ids] + + uniform_rand = torch.rand(batch_size, + k, + dtype=self.probs_dtype, + device=target_probs.device) + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of + :math:`x` given context :math:`x_1, \dots, x_n` according to the target + model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability + according to the draft model: + + .. math:: + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + + where :math:`(f(x))_+` is defined as: + + .. math:: + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + + See https://github.com/vllm-project/vllm/pull/2336 for a visualization + of the draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. + + Note: This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. 
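+        When the draft and target distributions are identical, for instance,
+        the difference is zero everywhere and the normalizer
+        :math:`\sum_x \max(0, f(x))` is zero.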
This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + recovered_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via rejection sampling, all subsequent + token ids are set to -1 for the sequence. + + shape = [batch_size, k + num_bonus_tokens] + """ + bonus_token_ids = bonus_token_ids.squeeze() + batch_size, k = recovered_token_ids.shape + + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. 
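+        # after_false_mask is True only at the first rejected position of each
+        # row, so this in-place mul/add writes the recovered token there while
+        # leaving accepted draft tokens and trailing -1 entries unchanged.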
+ output.mul_(~after_false_mask).add_( + recovered_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + assert draft_token_ids_batch_size == draft_batch_size + assert num_draft_token_ids == num_draft_probs + + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert all(probs.dtype == self.probs_dtype + for probs in [target_probs, draft_probs]) + assert all(token_ids.dtype == self.token_id_dtype + for token_ids in [bonus_token_ids, draft_token_ids]) + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +@torch.jit.script +def _multinomial( + probs: torch.Tensor, + num_samples: int, +) -> torch.Tensor: + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs).exponential_(1.0) + return probs.div_(q).argmax(dim=1).view(-1, num_samples)
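+# Note: dividing each row of probs by i.i.d. Exponential(1) noise and taking
+# the argmax is the exponential-race form of the Gumbel-max trick, so each
+# row is an independent draw from its (normalized) categorical distribution.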