aaa #19 (Open)

Wants to merge 2 commits into main.

31 changes: 25 additions & 6 deletions outlines/integrations/vllm.py
@@ -27,7 +27,7 @@

 import math
 from collections import defaultdict
-from typing import TYPE_CHECKING, DefaultDict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, DefaultDict, List, Optional, Tuple, Type, Union

 import torch
 from pydantic import BaseModel
@@ -101,11 +101,30 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         # Initialize the FSM state dictionary if the input_ids are empty, as this means
         # that the input_ids are the first tokens of the sequence.
         if len(input_ids) > 0:
-            last_token = input_ids[-1]
-            last_seq_id = hash(tuple(input_ids[:-1]))
-            self._fsm_state[seq_id] = self.fsm.get_next_state(
-                state=self._fsm_state[last_seq_id], token_id=last_token
-            )
+            # Create a stack of sequence offsets to process. Typically (without
+            # fast-forward) only the most recent single sequence offset is
+            # processed, and the FSM only advances by one state.
+
+            token_seq_transitions: List[Tuple[int, int, int]] = []
+
+            next_seq_id = seq_id
+            for offset in range(1, len(input_ids) + 1):
+                prev_seq = input_ids[:-offset]
+                prev_seq_id = hash(tuple(prev_seq))
+                token_seq_transitions.append(
+                    (input_ids[-offset], prev_seq_id, next_seq_id)
+                )
+                if prev_seq_id in self._fsm_state:
+                    break
+                next_seq_id = prev_seq_id
+            else:
+                raise RuntimeError("Failed to find a prior processed sequence in FSM")
+
+            # Apply all unfulfilled FSM state transitions, oldest first.
+            for token, prev_seq_id, next_seq_id in reversed(token_seq_transitions):
+                self._fsm_state[next_seq_id] = self.fsm.get_next_state(
+                    state=self._fsm_state[prev_seq_id], token_id=token
+                )

         allowed_tokens = self.fsm.get_next_instruction(
             state=self._fsm_state[seq_id]
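Note for reviewers: below is a minimal standalone sketch of the replay logic introduced above, assuming a toy FSM whose state simply counts consumed tokens. The names toy_next_state and advance are illustrative, not part of the outlines API; the real processor calls self.fsm.get_next_state and keys cached states by hash(tuple(input_ids)), exactly as in the diff.

    from typing import Dict, List, Tuple

    def toy_next_state(state: int, token_id: int) -> int:
        # Stand-in for fsm.get_next_state: each consumed token advances the state by one.
        return state + 1

    def advance(fsm_state: Dict[int, int], input_ids: List[int]) -> int:
        """Replay any missing transitions, then return the state for input_ids."""
        seq_id = hash(tuple(input_ids))
        transitions: List[Tuple[int, int, int]] = []
        next_seq_id = seq_id
        # Walk back through ever-shorter prefixes until one has a cached state.
        for offset in range(1, len(input_ids) + 1):
            prev_seq_id = hash(tuple(input_ids[:-offset]))
            transitions.append((input_ids[-offset], prev_seq_id, next_seq_id))
            if prev_seq_id in fsm_state:
                break
            next_seq_id = prev_seq_id
        else:
            raise RuntimeError("Failed to find a prior processed sequence in FSM")
        # Apply the pending transitions oldest-first so each lookup is already cached.
        for token, prev_seq_id, nxt in reversed(transitions):
            fsm_state[nxt] = toy_next_state(fsm_state[prev_seq_id], token)
        return fsm_state[seq_id]

    fsm_state = {hash(tuple()): 0}             # only the empty prefix is known
    assert advance(fsm_state, [5]) == 1        # incremental decoding: one transition
    assert advance(fsm_state, [5, 7, 9]) == 3  # fast-forward: replays tokens 7 and 9

Applying the recorded transitions in reverse (oldest first) guarantees that each fsm_state[prev_seq_id] lookup hits a state cached earlier in the same pass.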
31 changes: 31 additions & 0 deletions tests/generate/test_integration_vllm.py
@@ -1,5 +1,6 @@
 import datetime
 import re
+from unittest.mock import MagicMock, patch

 import pytest
 import torch
@@ -10,6 +11,7 @@
 import outlines.grammars as grammars
 import outlines.models as models
 import outlines.samplers as samplers
+from outlines.integrations.vllm import RegexLogitsProcessor

 pytestmark = pytest.mark.skipif(
     not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
@@ -237,3 +239,32 @@ def test_vllm_cfg(model):
     prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n"
     result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11)
     assert isinstance(result, str)
+
+
+def test_multiple_tokens_fast_forward_regex_logits_processor():
+    """Reproduces https://github.com/outlines-dev/outlines/issues/855"""
+
+    # Create a RegexLogitsProcessor and call it with non-incremental (fast-forward) sequences.
+
+    class MockTokenizer:
+        vocabulary = {"1": 1, "a": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    mock_scores = torch.tensor([0.5, 0.5, 0.5, 0.5])
+
+    with patch(
+        "outlines.integrations.vllm.adapt_tokenizer", return_value=MockTokenizer()
+    ):
+        lp = RegexLogitsProcessor(r".*", MagicMock())
+
+        # Patch DefaultDict -> dict; ensure KeyError for unseen non-initial sequences.
+        lp._fsm_state = {hash(tuple()): 0}
+
+        lp([], mock_scores)
+        lp([1], mock_scores)
+        lp([1, 2], mock_scores)
+        lp([1, 2, 3, 4], mock_scores)
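The final call is the fast-forward case: the processor last cached the state for [1, 2], so [1, 2, 3, 4] arrives with the intermediate prefix [1, 2, 3] unseen. With _fsm_state patched to a plain dict, the previous single-step lookup raised a KeyError here; the new code walks back to the cached [1, 2] state and replays the transitions for tokens 3 and 4.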