diff --git a/tests/entrypoints/test_llm_generate_multiple_loras.py b/tests/entrypoints/test_llm_generate_multiple_loras.py
new file mode 100644
index 0000000000000..b429b904c7c35
--- /dev/null
+++ b/tests/entrypoints/test_llm_generate_multiple_loras.py
@@ -0,0 +1,69 @@
+import weakref
+
+import pytest
+# downloading lora to test lora requests
+from huggingface_hub import snapshot_download
+
+from vllm import LLM
+from vllm.lora.request import LoRARequest
+
+from ..conftest import cleanup
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
+
+pytestmark = pytest.mark.llm
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME,
+              tensor_parallel_size=1,
+              max_model_len=8192,
+              enable_lora=True,
+              max_loras=4,
+              max_lora_rank=64,
+              max_num_seqs=128,
+              enforce_eager=True)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+        del llm
+
+    cleanup()
+
+
+@pytest.fixture(scope="session")
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.mark.skip_global_cleanup
+def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
+    lora_request = [
+        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
+        for idx in range(len(PROMPTS))
+    ]
+    # Multiple LoRARequests should be matched with each prompt
+    outputs = llm.generate(PROMPTS, lora_request=lora_request)
+    assert len(PROMPTS) == len(outputs)
+
+    # ValueError is raised if the lengths of prompts and lora_request differ
+    with pytest.raises(ValueError):
+        outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])
+
+    # A single LoRARequest should be applied to every prompt
+    single_lora_request = lora_request[0]
+    outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
+    assert len(PROMPTS) == len(outputs)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 905c36afde1e0..411d5256b75b9 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -170,7 +170,7 @@ def generate(
                                         List[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
 
@@ -182,7 +182,7 @@ def generate(
                                         List[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
 
@@ -195,7 +195,7 @@ def generate(
         *,
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
 
@@ -208,7 +208,7 @@ def generate(
         *,
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
 
@@ -219,7 +219,7 @@ def generate(
         sampling_params: None,
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
 
@@ -232,7 +232,7 @@ def generate(
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
 
@@ -249,7 +249,7 @@ def generate(
                                         Sequence[SamplingParams]]] = None,
         prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
@@ -312,7 +312,7 @@ def encode(
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...
 
@@ -324,7 +324,7 @@ def encode(
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...
 
@@ -337,7 +337,7 @@ def encode(
         *,
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...
 
@@ -350,7 +350,7 @@ def encode(
         *,
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...
 
@@ -361,7 +361,7 @@ def encode(
         pooling_params: None,
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...
 
@@ -374,7 +374,7 @@ def encode(
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...
 
@@ -391,7 +391,7 @@ def encode(
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         """Generates the completions for the input prompts.
 
@@ -498,7 +498,7 @@ def _validate_and_add_requests(
         inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
-        lora_request: Optional[LoRARequest],
+        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
     ) -> None:
         if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.
@@ -509,20 +509,25 @@ def _validate_and_add_requests(
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
                              "must be the same.")
+        if isinstance(lora_request,
+                      list) and len(lora_request) != num_requests:
+            raise ValueError("The lengths of prompts and lora_request "
+                             "must be the same.")
 
         # Add requests to the engine.
         for i, request_inputs in enumerate(inputs):
             self._add_request(
                 request_inputs,
                 params[i] if isinstance(params, Sequence) else params,
-                lora_request=lora_request,
+                lora_request=lora_request[i] if isinstance(
+                    lora_request, Sequence) else lora_request,
             )
 
     def _add_request(
         self,
         inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,
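
For reference, a minimal usage sketch of the per-prompt LoRA dispatch this diff adds. It mirrors the test above; the adapter name and local path are illustrative placeholders, not values from this patch.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Any LoRA-enabled base model / adapter pair works; these names are examples.
    llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", enable_lora=True, max_loras=4)

    prompts = ["Hello, my name is", "The capital of France is"]

    # One LoRARequest per prompt: adapter i is applied to prompt i.
    lora_requests = [
        LoRARequest("zephyr-lora", i + 1, "/path/to/zephyr-7b-beta-lora")
        for i in range(len(prompts))
    ]

    outputs = llm.generate(prompts,
                           SamplingParams(temperature=0.0),
                           lora_request=lora_requests)

    # Passing a single LoRARequest is still supported; it is applied to every prompt.
    outputs = llm.generate(prompts, lora_request=lora_requests[0])

A list whose length does not match the number of prompts raises ValueError, per the new check in _validate_and_add_requests.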