Sampler Throughput Optimization #192

Merged: 30 commits, Feb 8, 2024. Showing changes from 23 commits.
25 changes: 19 additions & 6 deletions serve/mlc_serve/api/handler.py
@@ -30,6 +30,7 @@
SamplingParams,
StoppingCriteria,
)
from ..model.base import ModelArtifactConfig
from ..engine.async_connector import AsyncEngineConnector
from .dependencies import get_async_engine_connector

@@ -44,7 +45,9 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
router = APIRouter()


def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
def _get_sampling_params(
request: ChatCompletionRequest, model_artifact_config: ModelArtifactConfig
) -> SamplingParams:
sampling_params = SamplingParams(
# These params came from vllm
# TODO(amnalyshe): should they be put into mlc-llm batch serving ChatCompletionRequest?
@@ -68,6 +71,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
if request.logprobs:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs

sampling_params.vocab_size = model_artifact_config.vocab_size
return sampling_params


@@ -106,8 +111,10 @@ def random_uuid() -> str:

request_id = f"cmpl-{random_uuid()}"
model_name = request.model

model_artifact_config = async_engine_connector.engine.model_artifact_config
try:
sampling_params = _get_sampling_params(request)
sampling_params = _get_sampling_params(request, model_artifact_config)
except ValueError as e:
raise ValueError(
"""
@@ -130,7 +137,9 @@ def random_uuid() -> str:
messages=request.messages,
num_sequences=request.n,
sampling_params=sampling_params,
stopping_criteria=StoppingCriteria(max_tokens=request.max_tokens, stop_sequences=stop_sequences),
stopping_criteria=StoppingCriteria(
max_tokens=request.max_tokens, stop_sequences=stop_sequences
),
debug_options=DebugOptions(ignore_eos=ignore_eos),
)
if isinstance(request.messages, str):
@@ -196,7 +205,9 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None
logprob_info=Logprobs(content=seq.logprob_info)
if seq.logprob_info != []
else None,
)
for seq in res.sequences
]
@@ -217,7 +228,7 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
@@ -241,7 +252,9 @@
finish_reasons[seq.index] = seq.finish_reason.value # type: ignore

choices = []
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(
zip(logprob_infos, sequences, finish_reasons)
):
logprobs = None
if logprob_info_seq != []:
logprobs = Logprobs(content=logprob_info_seq)
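The handler change above threads the engine's `ModelArtifactConfig` into `_get_sampling_params`, so the resulting `SamplingParams` carries the model's vocabulary size (presumably so downstream sampling code can size vocabulary-dependent buffers). A minimal sketch of that flow, using simplified stand-in dataclasses rather than the real `ChatCompletionRequest`, `SamplingParams`, and `ModelArtifactConfig` types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModelArtifactConfig:
    """Stand-in for mlc_serve.model.base.ModelArtifactConfig (only the field used here)."""
    vocab_size: int


@dataclass
class FakeSamplingParams:
    """Stand-in for mlc_serve.engine.SamplingParams (subset of fields)."""
    temperature: float = 1.0
    top_p: float = 1.0
    logprobs: bool = False
    top_logprobs: int = 0
    vocab_size: Optional[int] = None


@dataclass
class FakeChatCompletionRequest:
    temperature: float = 1.0
    top_p: float = 1.0
    logprobs: bool = False
    top_logprobs: int = 0


def get_sampling_params(
    request: FakeChatCompletionRequest, model_artifact_config: FakeModelArtifactConfig
) -> FakeSamplingParams:
    params = FakeSamplingParams(temperature=request.temperature, top_p=request.top_p)
    if request.logprobs:
        params.top_logprobs = request.top_logprobs
        params.logprobs = request.logprobs
    # The new argument lets the sampler know the model's vocabulary size.
    params.vocab_size = model_artifact_config.vocab_size
    return params


if __name__ == "__main__":
    config = FakeModelArtifactConfig(vocab_size=32000)
    request = FakeChatCompletionRequest(logprobs=True, top_logprobs=5)
    print(get_sampling_params(request, config))
```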
8 changes: 6 additions & 2 deletions serve/mlc_serve/engine/__init__.py
@@ -17,6 +17,10 @@
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
RawLogprobsInfo,
RawLogprobsInfos,
)
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
from .sampling_params import (
SamplingParams,
SamplingType,
LOGPROB_TOP_K_MAX,
_SAMPLING_EPS as SAMPLING_EPS,
)
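`_SAMPLING_EPS` is now re-exported as `SAMPLING_EPS`. A common use for such an epsilon, and an assumption about how it is applied in this codebase, is to treat near-zero temperatures as greedy sampling; a hedged sketch (the constant's actual value lives in `engine/sampling_params.py`):

```python
SAMPLING_EPS = 1e-5  # assumed value for illustration only


def is_effectively_greedy(temperature: float, eps: float = SAMPLING_EPS) -> bool:
    # Temperatures this close to zero make softmax scaling numerically unstable,
    # so callers typically switch to argmax (greedy) decoding instead.
    return temperature < eps


assert is_effectively_greedy(0.0)
assert not is_effectively_greedy(0.7)
```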
35 changes: 25 additions & 10 deletions serve/mlc_serve/engine/async_connector.py
@@ -18,6 +18,7 @@

ResultQueue = asyncio.Queue[RequestOutput]


class AsyncEngineConnector:
def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine = engine
@@ -32,7 +33,10 @@ async def start(self):
"""
Needs to be called in the thread with event loop
"""
LOG.info("Starting AsyncEngineConnector.", engine_wait_timeout=self.engine_wait_timeout)
LOG.info(
"Starting AsyncEngineConnector.",
engine_wait_timeout=self.engine_wait_timeout,
)
if self.engine_loop_task is not None:
return

@@ -78,7 +82,7 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]:
try:
queue = await self._add_request(request)
while True:
# TODO(jknight): Should make sure we are catching cancellations
# TODO(jknight): Should make sure we are catching cancellations
# correctly inside this _get_queue...(...) awaitable object too
output = await self._get_queue_item_until_stopped(queue)
if output.error is not None:
@@ -87,21 +91,32 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]:
if output.is_finished:
return
except asyncio.CancelledError:
LOG.info("AsyncEngineConnector.generate iterator cancelled.", request_id=request.request_id)
LOG.info(
"AsyncEngineConnector.generate iterator cancelled.",
request_id=request.request_id,
)
# Running this sync because `await` inside of cancellation events is problematic
self.engine.cancel(request.request_id)
LOG.info("AsyncEngineConnector.generate request sucessfully cancelled.", request_id=request.request_id)
LOG.info(
"AsyncEngineConnector.generate request sucessfully cancelled.",
request_id=request.request_id,
)
self.recent_cancelled_requests.appendleft(request.request_id)
# Always re-raise CancellationErrors unless you know what you're doing.
raise
finally:
LOG.info("AsyncEngineConnector.generate removing request from result queue.", request_id=request.request_id)
LOG.info(
"AsyncEngineConnector.generate removing request from result queue.",
request_id=request.request_id,
)
self.result_queues.pop(request.request_id, None)

async def _get_queue_item_until_stopped(self, queue: ResultQueue) -> RequestOutput:
try:
get_queue_task = asyncio.create_task(queue.get(), name="get_queue_task")
wait_shutdown_task = asyncio.create_task(self.shutdown_event.wait(), name="wait_shutdown_task")
wait_shutdown_task = asyncio.create_task(
self.shutdown_event.wait(), name="wait_shutdown_task"
)

await asyncio.wait(
(get_queue_task, wait_shutdown_task),
@@ -110,7 +125,9 @@ async def _get_queue_item_until_stopped(self, queue: ResultQueue) -> RequestOutput:

if wait_shutdown_task.done():
if self.engine_loop_exception is not None:
raise EngineException("raised while handling previous engine loop exception") from self.engine_loop_exception
raise EngineException(
"raised while handling previous engine loop exception"
) from self.engine_loop_exception
else:
raise EngineException("stopped with no exception")

@@ -147,8 +164,6 @@ async def _dispatch_result(self, result: InferenceStepResult) -> None:
elif request_id not in self.recent_cancelled_requests:
# if request is not in result_queues, and not cancelled recently,
# something goes wrong and we want to be aware of it.
LOG.warn(
f"Unknown request id when dispatching result: {request_id}"
)
LOG.warn(f"Unknown request id when dispatching result: {request_id}")

await asyncio.gather(*coroutines)
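The reformatted `_get_queue_item_until_stopped` races a `queue.get()` against the connector's shutdown event via `asyncio.wait(..., return_when=FIRST_COMPLETED)`, so a stopped or crashed engine loop surfaces as an `EngineException` rather than a hang. A self-contained sketch of that pattern (illustrative names, not the real class):

```python
import asyncio


class EngineStopped(Exception):
    pass


async def get_until_stopped(queue: asyncio.Queue, shutdown: asyncio.Event):
    get_task = asyncio.create_task(queue.get(), name="get_queue_task")
    stop_task = asyncio.create_task(shutdown.wait(), name="wait_shutdown_task")
    try:
        await asyncio.wait({get_task, stop_task}, return_when=asyncio.FIRST_COMPLETED)
        if stop_task.done():
            # The engine loop stopped before a result arrived.
            raise EngineStopped("engine stopped with no result")
        return get_task.result()
    finally:
        # Whichever task did not finish must be cancelled to avoid leaks.
        for task in (get_task, stop_task):
            task.cancel()


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    shutdown = asyncio.Event()
    queue.put_nowait("hello")
    print(await get_until_stopped(queue, shutdown))  # -> "hello"
    shutdown.set()
    try:
        await get_until_stopped(queue, shutdown)
    except EngineStopped as e:
        print("stopped:", e)


if __name__ == "__main__":
    asyncio.run(main())
```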
15 changes: 10 additions & 5 deletions serve/mlc_serve/engine/base.py
@@ -1,15 +1,16 @@
from __future__ import annotations
import structlog
import torch
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod

from typing import List, Callable, Any, Optional, Dict
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..openai_logprob_protocol import LogprobsContent
from ..model.base import ModelArtifactConfig

LOG = structlog.stdlib.get_logger(__name__)
RequestId = str
@@ -19,10 +20,8 @@
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]
top_token_ids: Optional[torch.Tensor]
top_logprobs: Optional[torch.Tensor]


# TODO(@sunggg): consider transition to something like Pydantic
@@ -201,6 +200,12 @@ class InferenceStepResult:


class InferenceEngine(ABC):
"""
Expose the model config to the high-level APIs.
"""

model_artifact_config: ModelArtifactConfig

@abstractmethod
def add(self, requests: list[Request]) -> None:
"""
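`RawLogprobsInfo` now keeps its top-k candidates as `torch.Tensor`s rather than NumPy arrays, presumably to avoid device-to-host copies on the sampler's hot path. A hedged sketch of how such a record could be filled from one logits row (`make_logprob_info` is illustrative, not the project's actual helper):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class RawLogprobsInfo:
    current_token_id: int
    current_logprob: float
    top_token_ids: Optional[torch.Tensor]
    top_logprobs: Optional[torch.Tensor]


def make_logprob_info(logits: torch.Tensor, chosen_token: int, top_k: int) -> RawLogprobsInfo:
    # Convert one row of logits to log-probabilities and keep the top-k
    # candidates as tensors (no .cpu()/.numpy() conversion here).
    logprobs = torch.log_softmax(logits, dim=-1)
    top_logprobs, top_token_ids = torch.topk(logprobs, k=top_k)
    return RawLogprobsInfo(
        current_token_id=chosen_token,
        current_logprob=float(logprobs[chosen_token]),
        top_token_ids=top_token_ids,
        top_logprobs=top_logprobs,
    )


if __name__ == "__main__":
    row = torch.randn(32000)  # one vocabulary-sized logits row
    info = make_logprob_info(row, chosen_token=int(row.argmax()), top_k=5)
    print(info.top_token_ids, info.top_logprobs)
```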