-
-
Notifications
You must be signed in to change notification settings - Fork 4.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
prefill only inputs (data & preprocessor)
- Loading branch information
Showing
9 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# mypy: ignore-errors | ||
import pytest | ||
|
||
from vllm.inputs.prefill_only.data import (TextOnlyInputs, TextPrompt, | ||
TokensPrompt, ValidationError) | ||
from vllm.inputs.prefill_only.preprocessor import TextInputProcessor | ||
|
||
input_processor = TextInputProcessor() | ||
|
||
|
||
@pytest.fixture(scope="session")
def request_id():
    """A fixed request id shared across the whole test session."""
    return str(0)
|
||
|
||
def test_input_processor_1(request_id):
    # A bare string prompt is wrapped into a {"prompt": ...} dict.
    text = "test"
    request = input_processor(request_id, text)
    assert request.inputs == {"prompt": text}
|
||
|
||
def test_input_processor_2(request_id):
    # A TextPrompt object is unpacked into a plain dict.
    text = "test"
    request = input_processor(request_id, TextPrompt(prompt=text))
    assert request.inputs == {"prompt": text}
|
||
|
||
def test_input_processor_3(request_id):
    # A TokensPrompt object is unpacked into a plain dict.
    token_ids = [0]
    request = input_processor(request_id,
                              TokensPrompt(prompt_token_ids=token_ids))
    assert request.inputs == {"prompt_token_ids": token_ids}
|
||
|
||
def test_input_processor_4(request_id):
    # TextOnlyInputs without a prompt yields only token ids; with a
    # prompt it yields both keys.
    text = "test"
    token_ids = [0]

    request = input_processor(request_id,
                              TextOnlyInputs(prompt_token_ids=token_ids))
    assert request.inputs == {"prompt_token_ids": token_ids}

    request = input_processor(
        request_id, TextOnlyInputs(prompt_token_ids=token_ids, prompt=text))
    assert request.inputs == {
        "prompt_token_ids": token_ids,
        "prompt": text,
    }
|
||
|
||
def test_input_processor_5(request_id):
    # A dict that already carries the recognized keys passes through
    # unchanged.
    inputs = {"prompt_token_ids": [0], "prompt": "test"}
    request = input_processor(request_id, inputs)
    assert request.inputs == inputs
|
||
|
||
def test_validation_error(request_id):
    # Each of these is rejected: empty dict, dict without the recognized
    # keys, and non-prompt scalar types.
    for bad_inputs in ({}, {"foo": "bar"}, 0, 0.0):
        with pytest.raises(ValidationError):
            input_processor(request_id, bad_inputs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import pytest | ||
|
||
from vllm.inputs.prefill_only.data import TextOnlyInputs, TokensPrompt | ||
from vllm.inputs.prefill_only.preprocessor import (TextInputProcessor, | ||
TextRequestProcessor) | ||
from vllm.inputs.prefill_only.tokenizer import Tokenizer | ||
|
||
|
||
@pytest.fixture(scope="session")
def request_id():
    """A fixed request id shared across the whole test session."""
    return str(0)
|
||
|
||
TOKENIZER_NAMES = ["facebook/opt-125m", "gpt2"] | ||
|
||
|
||
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
def test_request_processor(request_id: str, tokenizer_name: str):
    """End-to-end: input processor output feeds the request processor."""
    tokenizer = Tokenizer(tokenizer_name=tokenizer_name)
    input_processor = TextInputProcessor()
    request_processor = TextRequestProcessor(tokenizer)

    # Text prompt: the request processor must tokenize it.
    prompt = "test"
    request = input_processor(request_id, prompt)
    assert request.inputs == {"prompt": prompt}

    schedulable = request_processor(request)
    assert isinstance(schedulable.inputs, TextOnlyInputs)
    assert len(schedulable.inputs.prompt_token_ids) > 0

    # Pre-tokenized prompt: token ids flow through untouched.
    request = input_processor(request_id,
                              TokensPrompt(prompt_token_ids=[0]))
    schedulable = request_processor(request)
    assert isinstance(schedulable.inputs, TextOnlyInputs)
    assert len(schedulable.inputs.prompt_token_ids) > 0
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Optional, Union | ||
|
||
|
||
class Params:
    """Marker base class for per-request parameters (no fields yet)."""
|
||
|
||
class Inputs:
    """Marker base class for all prompt-input schemas below."""
|
||
|
||
@dataclass
class TextPrompt(Inputs):
    """A prompt supplied as raw text, tokenized later in the pipeline."""

    # The input text to be tokenized before passing to the model.
    prompt: str
|
||
|
||
@dataclass
class TokensPrompt(Inputs):
    """A prompt supplied as pre-computed token IDs."""

    # Token IDs passed to the model as-is (no tokenization needed).
    prompt_token_ids: List[int]
|
||
|
||
@dataclass
class TextOnlyInputs(Inputs):
    """Tokenized inputs, optionally keeping the original prompt text."""

    # The token IDs of the prompt.
    prompt_token_ids: List[int]

    # Original prompt text corresponding to the token IDs, if available.
    prompt: Optional[str] = None
|
||
|
||
PromptInput = Union[str, Dict, TextPrompt, TokensPrompt, TextOnlyInputs] | ||
|
||
|
||
@dataclass
class Request:
    """Base record for a request entering the engine."""

    # Unique identifier of this request.
    request_id: str
    # Wall-clock timestamp (seconds) at which the request arrived.
    arrival_time: float
|
||
|
||
@dataclass
class TextRequest(Request):
    """A request whose inputs are still a plain normalized dict."""

    # Carries "prompt" and/or "prompt_token_ids" keys.
    inputs: Dict
|
||
|
||
class ValidationError(ValueError):
    """Raised when request inputs are malformed or of an unsupported type."""
|
||
|
||
class SchedulableRequest(Request):
    """Marker base class for requests that are ready to be scheduled."""
|
||
|
||
@dataclass
class TextSchedulableRequest(SchedulableRequest):
    # Tokenized inputs (token ids, plus the original text if available).
    inputs: TextOnlyInputs

    @property
    def num_new_tokens(self) -> int:
        # Prefill-only: every prompt token counts as a new token to process.
        return len(self.inputs.prompt_token_ids)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import time | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Dict, Optional, cast | ||
|
||
from vllm.inputs.prefill_only.data import (Params, PromptInput, Request, | ||
SchedulableRequest, TextOnlyInputs, | ||
TextPrompt, TextRequest, | ||
TextSchedulableRequest, | ||
TokensPrompt, ValidationError) | ||
from vllm.inputs.prefill_only.tokenizer import Tokenizer | ||
|
||
|
||
class InputProcessor(ABC):
    """
    Input(request_id, inputs, params, arrival_time) -> InputProcessor -> Request
    """

    @abstractmethod
    def __call__(self,
                 request_id: str,
                 inputs: Optional[Any] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> Request:
        # Normalize raw user input into a Request; subclasses define
        # which input forms are accepted.
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_engine(cls, engine):
        # Alternate constructor: build a processor from an engine instance.
        raise NotImplementedError
|
||
|
||
class TextInputProcessor(InputProcessor):
    """Normalizes text-style prompt inputs into a TextRequest.

    Accepted forms: raw str, TextPrompt, TokensPrompt, TextOnlyInputs,
    or a dict carrying "prompt" and/or "prompt_token_ids".
    Raises ValidationError for any other type, or for a dict missing
    both recognized keys.
    """

    def __call__(self,
                 request_id: str,
                 inputs: Optional[PromptInput] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> TextRequest:
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}
        elif isinstance(inputs, TextPrompt):
            inputs = {"prompt": inputs.prompt}
        elif isinstance(inputs, TokensPrompt):
            inputs = {"prompt_token_ids": inputs.prompt_token_ids}
        elif isinstance(inputs, TextOnlyInputs):
            _inputs: Dict[str, Any] = {
                "prompt_token_ids": inputs.prompt_token_ids
            }

            # Only include the prompt key when the text is actually known.
            if inputs.prompt is not None:
                _inputs["prompt"] = inputs.prompt

            inputs = _inputs

        elif isinstance(inputs, dict):
            if "prompt" not in inputs and "prompt_token_ids" not in inputs:
                raise ValidationError(
                    'inputs must contain at least one of '
                    '"prompt" or "prompt_token_ids".')
            # Drop unrelated keys so downstream code sees a clean dict.
            inputs = {
                k: v
                for k, v in inputs.items()
                if k in {"prompt", "prompt_token_ids"}
            }
        else:
            raise ValidationError(
                f"Input does not support {type(inputs)} data type")

        # Use `is None` rather than truthiness so an explicit
        # arrival_time of 0.0 is preserved instead of being replaced
        # by the current time.
        if arrival_time is None:
            arrival_time = time.time()
        request = TextRequest(request_id=str(request_id),
                              inputs=inputs,
                              arrival_time=arrival_time)
        return request

    @classmethod
    def from_engine(cls, engine):
        # Stateless processor; nothing is needed from the engine.
        return cls()
|
||
|
||
class RequestProcessor(ABC):
    """
    Request -> RequestProcessor -> SchedulableRequest
    """

    @abstractmethod
    def __call__(self, request: Request) -> SchedulableRequest:
        # Transform a normalized Request into a schedulable one
        # (e.g. by tokenizing any remaining raw text).
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_engine(cls, engine):
        # Alternate constructor: build a processor from an engine instance.
        raise NotImplementedError
|
||
|
||
class TextRequestProcessor(RequestProcessor):
    """Turns a TextRequest into a TextSchedulableRequest, tokenizing the
    prompt text when token ids were not supplied by the caller."""

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, request: Request) -> TextSchedulableRequest:
        assert isinstance(request, TextRequest)

        request = cast(TextRequest, request)

        inputs = request.inputs

        # Tokenize only when the caller did not provide token ids.
        if "prompt_token_ids" in inputs:
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            prompt_token_ids = self.tokenizer.encode(inputs["prompt"])

        return TextSchedulableRequest(
            request_id=request.request_id,
            inputs=TextOnlyInputs(prompt_token_ids=prompt_token_ids,
                                  prompt=inputs.get("prompt")),
            arrival_time=request.arrival_time)

    @classmethod
    def from_engine(cls, engine):
        return cls(engine.tokenizer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from vllm.transformers_utils.tokenizer import get_tokenizer | ||
|
||
|
||
class Tokenizer:
    """Thin wrapper delegating to a tokenizer from get_tokenizer."""

    def __init__(self, tokenizer_name: str, **kwargs):
        # Remember construction arguments so they can be inspected later.
        self.tokenizer_name = tokenizer_name
        self.tokenizer_kwargs = kwargs

        self.tokenizer = get_tokenizer(tokenizer_name=self.tokenizer_name,
                                       **self.tokenizer_kwargs)

    @classmethod
    def from_engine(cls, engine):
        # Alternate constructor pulling settings from the engine config.
        model_config = engine.engine_config.model_config
        return cls(tokenizer_name=model_config.tokenizer,
                   tokenizer_mode=model_config.tokenizer_mode,
                   trust_remote_code=model_config.trust_remote_code,
                   revision=model_config.tokenizer_revision)

    def __call__(self, *args, **kwargs):
        # Forward calls straight to the underlying tokenizer.
        return self.tokenizer(*args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.tokenizer.encode(*args, **kwargs)

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id