-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add factories for logits_processors #38
base: main
Are you sure you want to change the base?
Conversation
This allows vllm to support stateful LPs that must be unique for each sequence. Signed-off-by: Max de Bayser <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @maxdebayser!
I feel like this could be simplified a bit, basically just change get_lm_format_enforcer_guided_decoding_logits_processor
and get_outlines_guided_decoding_logits_processor
to return factories instead of LPs
Signed-off-by: Max de Bayser <[email protected]>
To reduce the lines of the diff Signed-off-by: Max de Bayser <[email protected]>
Thanks for the suggestions, @njhill . I didn't want to change too much of the original code so I ended up going overboard with the async cleverness. It's much simpler now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @maxdebayser this looks much better!
The other thing we may need to look at is what to do when sequences are forked. I think this only applies to beam search. Is deep-copying the LPs the right thing to do? Could there be problems deep-copying arbitrary LPs?
@maxdebayser after addressing the simple comments above (not necessarily the pooling thing yet), maybe you could open an upstream PR? Then we can continue the discussions with others... |
Signed-off-by: Max de Bayser <[email protected]>
Here's the upstream PR: vllm-project/vllm#5329 |
This is a proposed solution to the guided decoding crash problem.
In vLLM there are two requests that can be used to submit more than one sequence at the same time: the completion request in the legacy OpenAI API, and the generate batch request in the tgis gRPC API. In both cases the sampling params are validated only once and a single SamplingsParam object is shared between all sequences, even when they belong to different sequence groups. The SamplingsParam is mostly a data object, but it has a list of logits processors that are executable. If the logits processors are stateless there is no problem. However, the
CFGFSM
used by theCFGLogitsProcessor
has an internal state that depends on the sequence generated so far and is updated at each iteration. This causes it to raise a KeyError and crash the asyncio event loop.The first attempted solution added a seq_id parameter to the logits processor call so that it could manage state internally, but that changed the interface that is already used by some external libraries.
The solution proposed here is based on adding factories for stateful logits processors. The basic idea is:
and copy the logits_processors and call the factories to populate SequenceGroupState.logits_processors
the sampling_params.logits_processors
Here are some diagrams to illustrate the current code structure to better visualize the proposed changes:
The idea is quite simple, but the execution is a bit tricky due to the nature of
async
code in python where an async call can't call a non-async function that calls an async function. In the PR I tried to support both using LLMEngine directly as well as the AsyncLLMEngine used for serving.@njhill, I was going to add support to return the processors to the factory, but I realized that it was a little bit more complicated because only the scheduler knows when the sequence is done. Maybe we can add a callback somewhere in the scheduler where we can add the deallocation call. Actually that might be required, because I realized that there is another hidden bug: when sequence are preempted with the recompute policy, that makes the state of the logits processor invalid.