Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Core] Add Guidance backend for guided decoding #10217

Open
wants to merge 34 commits into
base: main
Choose a base branch
from

Conversation

JC1DA
Copy link

@JC1DA JC1DA commented Nov 11, 2024

This pull request extends guided decoding capabilities

guidance backend supports regex, choice, json and grammar.

relevant: #5245

Usage

  • JSON Generation
from pydantic import BaseModel, ConfigDict

model = "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4"
llm = LLM(model=model)

class UserProfile(BaseModel):
    name: str
    age: int
    email: str

    model_config = ConfigDict(extra="forbid")

sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(
        json=UserProfile,
        backend="guidance",
    ),
)

outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="system", content="You are a helpful assistant."
            ),
            CustomChatCompletionMessageParam(
                role="user",
                content="Tell me something about yourself (name, age, email) in JSON format.\n",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)
  • Choices Generation
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(
        choice=["3","4","5","6"],
        backend="guidance",
    ),
)

outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="system", content="You are a 5 years-old helpful assistant."
            ),
            CustomChatCompletionMessageParam(
                role="user",
                content="How old are you?",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)
  • Regex Generation via OpenAI Client
model = "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4"
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="NOKEY",
)

completion = client.chat.completions.create(
    model=model,
    messages=[
        {
            "role": "user",
            "content": "You are a 5 years-old helpful assistant.",
        },
        {
            "role": "user",
            "content": """How old are you?""",
        },
    ],
    extra_body={"guided_regex": "\\d+", "guided_decoding_backend": "guidance"}
)

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Nov 11, 2024
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @JC1DA for the great contribution!

A few other questions:

  • Presumably the parallelization speedup is due to the fact that the pytorch ops involved release the gil?
  • Were your outlines measurements also using the threadpool?
  • It would be good to also try with the latest outlines 0.1.x if possible which is apparently much faster than < 0.1. We would want to upgrade to that too in any case.

@@ -8,7 +8,7 @@ async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_params.backend == 'outlines':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LMFE doesn't support grammar, we should retain the existing behaviour to fall back to a different backend in this case (perhaps it could now be guidance rather than outlines).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a check at the beginning to use outlines by default?

    if guided_params.grammar and guided_params.backend not in [
            'outlines', 'guidance'
    ]:
        guided_params.backend = 'outlines'

vllm/model_executor/layers/logits_processor.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/logits_processor.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/logits_processor.py Outdated Show resolved Hide resolved
vllm/model_executor/guided_decoding/guidance_decoding.py Outdated Show resolved Hide resolved
Comment on lines 170 to 172
mask = torch.tensor(mask,
dtype=logits.dtype,
device=logits.device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the allocated mask tensor be reused between calls?

Copy link
Author

@JC1DA JC1DA Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill I have updated the code to reuse the logits variable. as we don't add thread pool into this PR anymore, it should work great with inplace ops.

@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.6
outlines >= 0.0.43, < 0.1
guidance>=0.2rc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the requirements of guidance? Does it have compiled binaries for specific python versions or CPU architectures?
Maybe this could be an optional dependency to start with, like we do for many quantization backends

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @mgoin, guidance does have a fair number of dependencies, but we're mostly depending on the lower-level guidance layer here llguidance. llguidance is compiled for Python 3.9+ on manylinux/Mac OS/Windows. My understanding is that vLLM only supports Linux on Python 3.9+ too so I think we should be good there.

We can change this PR in the near future to just use llguidance (which has no other dependencies: https://github.com/microsoft/llguidance/blob/b5ca97b2562b720c1ff3f567bfa45956338a1864/pyproject.toml#L8). We just need to port one last function down from the Python guidance library into the Rust layer first :).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgoin we replaced guidance with llguidance which has no extra dependencies. Hope it is good enough to merge :)

@JC1DA
Copy link
Author

JC1DA commented Nov 14, 2024

Thanks @njhill for your quick review. Really appreciate it.

  • Presumably the parallelization speedup is due to the fact that the pytorch ops involved release the gil?

That's one reason, another one is the parser (llguidance) used in guidance was implemented in Rust, and it automatically releases GIL when called. So it would be more efficient to run guidance in parallel.

  • Were your outlines measurements also using the threadpool?

Yes, experiments were done using threadpool

  • It would be good to also try with the latest outlines 0.1.x if possible which is apparently much faster than < 0.1. We
    would want to upgrade to that too in any case.

I haven't tested outlines 0.1.x yet, just used the current version in VLLM. However, I am not focusing too much on the benchmark for this PR. The focus is to make guidance available as another guided decoding backend to VLLM's community so people can choose what's best for them. :)

@JC1DA
Copy link
Author

JC1DA commented Nov 14, 2024

I also figured out lm-format-enforcer is not thread-safe. It failed some tests when number of threads is larger than 1.
@njhill any suggestions for this?

@Harsha-Nori Harsha-Nori mentioned this pull request Nov 15, 2024
39 tasks
@JC1DA
Copy link
Author

JC1DA commented Nov 25, 2024

I also figured out lm-format-enforcer is not thread-safe. It failed some tests when number of threads is larger than 1. @njhill any suggestions for this?

Decided to rollback to single threaded version to not break lm-format-enforcer. The PR is coming with minimal changes to add llguidance as new logits processor.
Hope the current code is good for merging :) @njhill @mgoin

@JC1DA JC1DA closed this Nov 25, 2024
@JC1DA JC1DA reopened this Nov 25, 2024
Copy link

mergify bot commented Dec 3, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JC1DA.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 3, 2024
Signed-off-by: Loc Huynh <[email protected]>
@mergify mergify bot removed the needs-rebase label Dec 3, 2024
@JC1DA
Copy link
Author

JC1DA commented Dec 3, 2024

Resolved conflict with newly merged xgrammar

Copy link

mergify bot commented Dec 3, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JC1DA.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 3, 2024
@mergify mergify bot removed the needs-rebase label Dec 3, 2024
@JC1DA
Copy link
Author

JC1DA commented Dec 5, 2024

Resolved conflict with newly merged xgrammar

@njhill @mgoin

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants