Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add factories for logits_processors #38

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft

Conversation

maxdebayser
Copy link
Contributor

This is a proposed solution to the guided decoding crash problem.

In vLLM there are two requests that can be used to submit more than one sequence at the same time: the completion request in the legacy OpenAI API, and the generate batch request in the tgis gRPC API. In both cases the sampling params are validated only once and a single SamplingsParam object is shared between all sequences, even when they belong to different sequence groups. The SamplingsParam is mostly a data object, but it has a list of logits processors that are executable. If the logits processors are stateless there is no problem. However, the CFGFSM used by the CFGLogitsProcessor has an internal state that depends on the sequence generated so far and is updated at each iteration. This causes it to raise a KeyError and crash the asyncio event loop.

The first attempted solution added a seq_id parameter to the logits processor call so that it could manage state internally, but that changed the interface that is already used by some external libraries.

The solution proposed here is based on adding factories for stateful logits processors. The basic idea is:

  1. We add processors and factories to the same list so that they are in the correct order
  2. We add a logits_processors list to the SequenceGroupState object
  3. When the SequenceGroup is created, we iterate over the sampling_params.logits_processors
    and copy the logits_processors and call the factories to populate SequenceGroupState.logits_processors
  4. The LogitsProcessor(nn.Module) will iterate over the SequenceGroupState.logits_processors instead of
    the sampling_params.logits_processors

Here are some diagrams to illustrate the current code structure to better visualize the proposed changes:
vllm_sampling
vllm_seq_classes

The idea is quite simple, but the execution is a bit tricky due to the nature of async code in python where an async call can't call a non-async function that calls an async function. In the PR I tried to support both using LLMEngine directly as well as the AsyncLLMEngine used for serving.

@njhill, I was going to add support to return the processors to the factory, but I realized that it was a little bit more complicated because only the scheduler knows when the sequence is done. Maybe we can add a callback somewhere in the scheduler where we can add the deallocation call. Actually that might be required, because I realized that there is another hidden bug: when sequence are preempted with the recompute policy, that makes the state of the logits processor invalid.

This allows vllm to support stateful LPs that must be
unique for each sequence.

Signed-off-by: Max de Bayser <[email protected]>
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @maxdebayser!

I feel like this could be simplified a bit, basically just change get_lm_format_enforcer_guided_decoding_logits_processor and get_outlines_guided_decoding_logits_processor to return factories instead of LPs

vllm/engine/llm_engine.py Outdated Show resolved Hide resolved
vllm/entrypoints/openai/serving_chat.py Outdated Show resolved Hide resolved
vllm/entrypoints/openai/serving_completion.py Outdated Show resolved Hide resolved
vllm/sequence.py Outdated Show resolved Hide resolved
vllm/model_executor/guided_decoding/outlines_decoding.py Outdated Show resolved Hide resolved
vllm/model_executor/guided_decoding/__init__.py Outdated Show resolved Hide resolved
vllm/model_executor/guided_decoding/__init__.py Outdated Show resolved Hide resolved
vllm/model_executor/sampling_metadata.py Outdated Show resolved Hide resolved
To reduce the lines of the diff

Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser
Copy link
Contributor Author

Thanks for the suggestions, @njhill . I didn't want to change too much of the original code so I ended up going overboard with the async cleverness. It's much simpler now.

Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @maxdebayser this looks much better!

The other thing we may need to look at is what to do when sequences are forked. I think this only applies to beam search. Is deep-copying the LPs the right thing to do? Could there be problems deep-copying arbitrary LPs?

vllm/sampling_params.py Outdated Show resolved Hide resolved
vllm/model_executor/guided_decoding/__init__.py Outdated Show resolved Hide resolved
vllm/sampling_params.py Outdated Show resolved Hide resolved
@njhill
Copy link
Member

njhill commented Jun 6, 2024

@maxdebayser after addressing the simple comments above (not necessarily the pooling thing yet), maybe you could open an upstream PR? Then we can continue the discussions with others...

Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser
Copy link
Contributor Author

Here's the upstream PR: vllm-project/vllm#5329

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants