[Frontend] enable passing multiple LoRA adapters at once to generate()
mgoldey authored Jun 6, 2024
1 parent abe855d commit 828da0d
Showing 2 changed files with 91 additions and 17 deletions.
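Taken together, the two changes let `LLM.generate()` accept either a single `LoRARequest` (applied to every prompt) or one `LoRARequest` per prompt. A minimal usage sketch, assuming the same base model and adapter that the new test downloads with `snapshot_download`; the variable names here are illustrative and not part of the commit:

```python
from huggingface_hub import snapshot_download

from vllm import LLM
from vllm.lora.request import LoRARequest

prompts = ["Hello, my name is", "The capital of France is"]
lora_path = snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", enable_lora=True, max_loras=4)

# One LoRARequest per prompt; the list length must match the number of
# prompts, otherwise generate() raises ValueError.
per_prompt = [
    LoRARequest("zephyr-lora", i + 1, lora_path)
    for i in range(len(prompts))
]
outputs = llm.generate(prompts, lora_request=per_prompt)

# A single LoRARequest is still accepted and is applied to every prompt.
outputs = llm.generate(prompts, lora_request=per_prompt[0])
```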
69 changes: 69 additions & 0 deletions tests/entrypoints/test_llm_generate_multiple_loras.py
@@ -0,0 +1,69 @@
import weakref

import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from vllm import LLM
from vllm.lora.request import LoRARequest

from ..conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

LORA_NAME = "typeof/zephyr-7b-beta-lora"

pytestmark = pytest.mark.llm


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              tensor_parallel_size=1,
              max_model_len=8192,
              enable_lora=True,
              max_loras=4,
              max_lora_rank=64,
              max_num_seqs=128,
              enforce_eager=True)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup()


@pytest.fixture(scope="session")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.mark.skip_global_cleanup
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
    lora_request = [
        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
        for idx in range(len(PROMPTS))
    ]
    # Multiple LoRA requests should be matched with each prompt
    outputs = llm.generate(PROMPTS, lora_request=lora_request)
    assert len(PROMPTS) == len(outputs)

    # An exception is raised if the number of LoRA requests does not match
    # the number of prompts
    with pytest.raises(ValueError):
        outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])

    # A single LoRARequest should be applied to every prompt
    single_lora_request = lora_request[0]
    outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
    assert len(PROMPTS) == len(outputs)
39 changes: 22 additions & 17 deletions vllm/entrypoints/llm.py
@@ -170,7 +170,7 @@ def generate(
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

@@ -182,7 +182,7 @@ def generate(
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

@@ -195,7 +195,7 @@ def generate(
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

@@ -208,7 +208,7 @@ def generate(
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

@@ -219,7 +219,7 @@ def generate(
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

@@ -232,7 +232,7 @@ def generate(
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

@@ -249,7 +249,7 @@ def generate(
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.
@@ -312,7 +312,7 @@ def encode(
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -324,7 +324,7 @@ def encode(
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -337,7 +337,7 @@ def encode(
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -350,7 +350,7 @@ def encode(
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -361,7 +361,7 @@ def encode(
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -374,7 +374,7 @@ def encode(
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -391,7 +391,7 @@ def encode(
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the completions for the input prompts.
@@ -498,7 +498,7 @@ def _validate_and_add_requests(
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
-        lora_request: Optional[LoRARequest],
+        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
    ) -> None:
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
@@ -509,20 +509,25 @@ def _validate_and_add_requests(
        if isinstance(params, list) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
+        if isinstance(lora_request,
+                      list) and len(lora_request) != num_requests:
+            raise ValueError("The lengths of prompts and lora_request "
+                             "must be the same.")

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
-                lora_request=lora_request,
+                lora_request=lora_request[i] if isinstance(
+                    lora_request, Sequence) else lora_request,
            )

    def _add_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
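The dispatch rule added to `_validate_and_add_requests` is: a sequence of LoRA requests must have the same length as the prompt list and is indexed per request, while a single request (or `None`) is reused for every prompt. A standalone sketch of that rule, using plain integers as stand-in adapter handles and a hypothetical helper name `select_lora` (neither is part of this commit):

```python
from collections.abc import Sequence


def select_lora(lora_request, num_requests: int, i: int):
    # A sequence must match the number of prompts and is indexed per request.
    if isinstance(lora_request, Sequence):
        if len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        return lora_request[i]
    # A single value (or None) is broadcast to every prompt.
    return lora_request


assert select_lora(7, num_requests=3, i=2) == 7          # broadcast
assert select_lora([1, 2, 3], num_requests=3, i=1) == 2  # per prompt
```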
