[Bug]: AssertionError when using automatic prefix caching and prompt_logprobs #8268
Comments
Probably a similar issue to #5344 (the same assert fails); some more related issues come up when searching for it.
Not sure if it's any help, but I simplified the example a little bit. If the number of tokens in the prefix is > 16 and there is a full cache hit, then the assertion triggers.

```python
from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = LLM(
    model_path,
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# works
# prompt = TokensPrompt(prompt_token_ids=list(range(16)))
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")

# fails
prompt = TokensPrompt(prompt_token_ids=list(range(17)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
y = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
```
Another update: it looks like the crash is related to the block size. If the number of tokens in the cached prefix is greater than the block size, then the assertion is hit. 16 is the default block size, which is why I saw it first. As in the example below, if I use a block size of 32, then I can increase the length of the TokensPrompt to 32.

```python
from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = LLM(
    model_path,
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
    block_size=32,
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# works
prompt = TokensPrompt(prompt_token_ids=list(range(31)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)

# fails
prompt = TokensPrompt(prompt_token_ids=list(range(33)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)
```
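To make the pattern concrete, here is a tiny stand-alone sketch that only encodes the observed rule from the experiments above (a hypothesis about the behaviour, not a model of vLLM internals): the second, cache-hit call appears to fail exactly when the prompt length exceeds the block size.

```python
# Illustration of the observed pattern only; this does not model vLLM internals.
# Hypothesis from the experiments above: the second (fully cached) call fails
# whenever the prompt is longer than block_size, i.e. at least one full block
# of the prompt can be served from the prefix cache.
def expected_to_fail(num_prompt_tokens: int, block_size: int) -> bool:
    return num_prompt_tokens > block_size

for block_size, lengths in [(16, (16, 17)), (32, (31, 32, 33))]:
    for n in lengths:
        status = "fails" if expected_to_fail(n, block_size) else "works"
        print(f"block_size={block_size}, prompt_len={n}: {status}")
```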
Can you try out the new version of vLLM (0.6.3.post1)? I believe #9034 may have fixed this error by correctly populating Sequence.
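As a quick sanity check after upgrading, you can confirm which version is actually being imported (a minimal snippet; it assumes the installed package exposes `__version__`, which current vLLM releases do):

```python
# Print the vLLM version actually being imported, to confirm the upgrade took effect.
import vllm

print(vllm.__version__)  # expecting "0.6.3.post1" after upgrading
```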
#9034 does not fix the issue; I patched that PR in but can still reproduce the problem.
Unfortunately, I saw the same. I think I just got lucky when it worked.
Posted a fix in #3251 that solves some of the problems (maybe enough for you), but not all of them.
@ccolas this looks great.
Your current environment
The output of `python collect_env.py`
🐛 Describe the bug
I'm having issues using automatic prefix caching together with the prompt_logprobs option. The first call to the generate method goes through, but the second call errors with an AssertionError.

Reproduction code:
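(The original reproduction script is collapsed in the issue; a minimal sketch along the lines of the simplified example from the comments above would be:)

```python
# Minimal sketch based on the simplified reproducer in the comments above
# (not the reporter's original script): a second generate() call on a fully
# cached prompt longer than one block raises the AssertionError.
from vllm import LLM, SamplingParams, TokensPrompt

model = LLM(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)
prompt = TokensPrompt(prompt_token_ids=list(range(17)))

model.generate(prompt, sampling_params, use_tqdm=False)  # first call: OK
model.generate(prompt, sampling_params, use_tqdm=False)  # second call: AssertionError
```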
Full stack trace: