diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py
index 6ed56d71b..595f246a0 100644
--- a/outlines/integrations/vllm.py
+++ b/outlines/integrations/vllm.py
@@ -27,7 +27,7 @@
 import math
 from collections import defaultdict
-from typing import TYPE_CHECKING, DefaultDict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, DefaultDict, List, Optional, Tuple, Type, Union
 
 import torch
 from pydantic import BaseModel
 
@@ -101,11 +101,30 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         # Initialize the FSM state dictionary if the input_ids are empty, as this means
         # that the input_ids are the first tokens of the sequence.
         if len(input_ids) > 0:
-            last_token = input_ids[-1]
-            last_seq_id = hash(tuple(input_ids[:-1]))
-            self._fsm_state[seq_id] = self.fsm.get_next_state(
-                state=self._fsm_state[last_seq_id], token_id=last_token
-            )
+            # Build a stack of pending sequence transitions to process.
+            # Typically (without fast-forward) only the most recent sequence offset
+            # is processed, and the FSM advances by a single state.
+
+            token_seq_transitions: List[Tuple[int, int, int]] = []
+
+            next_seq_id = seq_id
+            for offset in range(1, len(input_ids) + 1):
+                prev_seq = input_ids[:-offset]
+                prev_seq_id = hash(tuple(prev_seq))
+                token_seq_transitions.append(
+                    (input_ids[-offset], prev_seq_id, next_seq_id)
+                )
+                if prev_seq_id in self._fsm_state:
+                    break
+                next_seq_id = prev_seq_id
+            else:
+                raise RuntimeError("Failed to find a prior processed sequence in FSM")
+
+            # Apply all pending FSM state transitions, oldest first.
+            for token, prev_seq_id, next_seq_id in reversed(token_seq_transitions):
+                self._fsm_state[next_seq_id] = self.fsm.get_next_state(
+                    state=self._fsm_state[prev_seq_id], token_id=token
+                )
 
         allowed_tokens = self.fsm.get_next_instruction(
             state=self._fsm_state[seq_id]
diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py
index 4634bc839..032790bf2 100644
--- a/tests/generate/test_integration_vllm.py
+++ b/tests/generate/test_integration_vllm.py
@@ -1,5 +1,6 @@
 import datetime
 import re
+from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
@@ -10,6 +11,7 @@
 import outlines.grammars as grammars
 import outlines.models as models
 import outlines.samplers as samplers
+from outlines.integrations.vllm import RegexLogitsProcessor
 
 pytestmark = pytest.mark.skipif(
     not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
@@ -237,3 +239,32 @@ def test_vllm_cfg(model):
     prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n"
     result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11)
     assert isinstance(result, str)
+
+
+def test_multiple_tokens_fast_forward_regex_logits_processor():
+    """Reproduces https://github.com/outlines-dev/outlines/issues/855"""
+
+    # Create a RegexLogitsProcessor and call it with non-incremental (fast-forward) sequences.
+
+    class MockTokenizer:
+        vocabulary = {"1": 1, "a": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    mock_scores = torch.tensor([0.5, 0.5, 0.5, 0.5])
+
+    with patch(
+        "outlines.integrations.vllm.adapt_tokenizer", return_value=MockTokenizer()
+    ):
+        lp = RegexLogitsProcessor(r".*", MagicMock())
+
+        # Replace the DefaultDict with a plain dict so unseen non-initial sequences raise KeyError.
+        lp._fsm_state = {hash(tuple()): 0}
+
+        lp([], mock_scores)
+        lp([1], mock_scores)
+        lp([1, 2], mock_scores)
+        lp([1, 2, 3, 4], mock_scores)
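
For reference, here is a minimal standalone sketch of the backtrack-and-replay idea the patch implements, with a toy counting FSM standing in for outlines' RegexFSM. The names `advance`, `fsm_state`, and `get_next_state` are illustrative only, not the library API:

```python
from typing import Dict, List, Tuple


def get_next_state(state: int, token_id: int) -> int:
    # Toy FSM: the state is simply the number of tokens consumed so far.
    return state + 1


def advance(fsm_state: Dict[int, int], input_ids: List[int]) -> int:
    """Advance to the FSM state for `input_ids`, replaying any transitions
    skipped by a fast-forwarded (multi-token) step."""
    seq_id = hash(tuple(input_ids))
    if not input_ids:
        fsm_state[seq_id] = 0  # seed the initial state for the empty sequence
        return fsm_state[seq_id]

    # Walk backwards through prefixes until one with a cached state is found.
    transitions: List[Tuple[int, int, int]] = []
    next_seq_id = seq_id
    for offset in range(1, len(input_ids) + 1):
        prev_seq_id = hash(tuple(input_ids[:-offset]))
        transitions.append((input_ids[-offset], prev_seq_id, next_seq_id))
        if prev_seq_id in fsm_state:
            break
        next_seq_id = prev_seq_id
    else:
        raise RuntimeError("Failed to find a prior processed sequence in FSM")

    # Replay the missing transitions oldest-first.
    for token, prev_seq_id, next_seq_id in reversed(transitions):
        fsm_state[next_seq_id] = get_next_state(fsm_state[prev_seq_id], token)
    return fsm_state[seq_id]


state: Dict[int, int] = {}
advance(state, [])                         # initial state 0
advance(state, [1])                        # normal decoding: one transition
assert advance(state, [1, 2, 3, 4]) == 4   # fast-forward: replays tokens 2, 3, 4
```

Only the newly discovered prefixes are replayed, so the common single-token case still performs exactly one `get_next_state` call, matching the behavior of the code being replaced.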