Sampler Throughput Optimization (#192)
* wip

* cleanup and fix the fallback path

* more cleanup

* more cleanup

* add sampler testcase

* wip - resolve merge conflict

* resolved merge conflict, move metadata preparation earlier

* more cleanup

* wip:logprob

* logprob works for sync

* remove unnecessary output tokens

* multi-decode, fallback path

* tested fallback

* wip

* remove the custom multinomial since there was not much impact; add sampling tests (see the sampling sketch after this list)

* more sampler test cases

* wip:mypy

* fix mypy

* fix

* expose model config to higher-level APIs

* reflect feedback

* fix mypy

* remove index shuffling in random/greedy

* wip

* better

* better

* more comments

* reflect feedback

* reflect comment
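
Several entries in the log above touch the sampling path itself: the custom multinomial kernel was dropped in favor of the stock PyTorch call, sampling tests were added, and index shuffling was removed from the random/greedy split. As a point of reference, here is a minimal sketch of the plain PyTorch sampling path such a change falls back to; the function name, shapes, and epsilon value are assumptions for illustration, not code from this commit.

import torch

def sample_tokens(logits: torch.Tensor, temperature: float, eps: float = 1e-5) -> torch.Tensor:
    # logits: [batch, vocab]. A near-zero temperature degenerates to greedy argmax;
    # otherwise scale the logits and draw one token per row with torch.multinomial.
    if temperature < eps:
        return torch.argmax(logits, dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

next_tokens = sample_tokens(torch.randn(4, 32000), temperature=0.8)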
sunggg authored Feb 8, 2024
1 parent 6f4f123 commit eae6ac4
Showing 16 changed files with 1,279 additions and 560 deletions.
25 changes: 19 additions & 6 deletions serve/mlc_serve/api/handler.py
@@ -30,6 +30,7 @@
SamplingParams,
StoppingCriteria,
)
from ..model.base import ModelArtifactConfig
from ..engine.async_connector import AsyncEngineConnector
from .dependencies import get_async_engine_connector

@@ -44,7 +45,9 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
router = APIRouter()


def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
def _get_sampling_params(
request: ChatCompletionRequest, model_artifact_config: ModelArtifactConfig
) -> SamplingParams:
sampling_params = SamplingParams(
# These params came from vllm
# TODO(amnalyshe): should they be put into mlc-llm batch serving ChatCompletionRequest?
@@ -68,6 +71,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
if request.logprobs:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs

sampling_params.vocab_size = model_artifact_config.vocab_size
return sampling_params


@@ -106,8 +111,10 @@ def random_uuid() -> str:

request_id = f"cmpl-{random_uuid()}"
model_name = request.model

model_artifact_config = async_engine_connector.engine.model_artifact_config
try:
sampling_params = _get_sampling_params(request)
sampling_params = _get_sampling_params(request, model_artifact_config)
except ValueError as e:
raise ValueError(
"""
@@ -130,7 +137,9 @@ def random_uuid() -> str:
messages=request.messages,
num_sequences=request.n,
sampling_params=sampling_params,
stopping_criteria=StoppingCriteria(max_tokens=request.max_tokens, stop_sequences=stop_sequences),
stopping_criteria=StoppingCriteria(
max_tokens=request.max_tokens, stop_sequences=stop_sequences
),
debug_options=DebugOptions(ignore_eos=ignore_eos),
)
if isinstance(request.messages, str):
@@ -196,7 +205,9 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None
logprob_info=Logprobs(content=seq.logprob_info)
if seq.logprob_info != []
else None,
)
for seq in res.sequences
]
@@ -217,7 +228,7 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
@@ -241,7 +252,9 @@
finish_reasons[seq.index] = seq.finish_reason.value # type: ignore

choices = []
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(
zip(logprob_infos, sequences, finish_reasons)
):
logprobs = None
if logprob_info_seq != []:
logprobs = Logprobs(content=logprob_info_seq)
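
The handler change above threads the engine's ModelArtifactConfig into _get_sampling_params so the sampler knows the model's vocabulary size up front. Below is a small self-contained sketch of that data flow, using hypothetical stand-in dataclasses rather than the real mlc_serve classes.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArtifactConfig:      # stand-in for mlc_serve.model.base.ModelArtifactConfig
    vocab_size: int

@dataclass
class SamplingParams:           # stand-in for mlc_serve.engine.SamplingParams
    temperature: float = 1.0
    top_p: float = 1.0
    vocab_size: Optional[int] = None

def get_sampling_params(temperature: float, top_p: float, config: ModelArtifactConfig) -> SamplingParams:
    # Mirrors the new handler logic: build the params from the request fields,
    # then stamp the model's vocab size on them before handing the request to the engine.
    params = SamplingParams(temperature=temperature, top_p=top_p)
    params.vocab_size = config.vocab_size
    return params

params = get_sampling_params(0.7, 0.95, ModelArtifactConfig(vocab_size=32000))
assert params.vocab_size == 32000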
8 changes: 6 additions & 2 deletions serve/mlc_serve/engine/__init__.py
@@ -17,6 +17,10 @@
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
RawLogprobsInfo,
RawLogprobsInfos,
)
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
from .sampling_params import (
SamplingParams,
SamplingType,
LOGPROB_TOP_K_MAX,
_SAMPLING_EPS as SAMPLING_EPS,
)
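
engine/__init__.py now re-exports _SAMPLING_EPS as SAMPLING_EPS alongside LOGPROB_TOP_K_MAX. The usual convention behind such an epsilon is that temperatures below it are treated as greedy decoding; the constant's value and the helper below are assumptions for illustration, not the definitions in sampling_params.py.

SAMPLING_EPS = 1e-5  # assumed value; the real constant lives in sampling_params.py

def is_greedy(temperature: float) -> bool:
    # Temperatures within epsilon of zero degenerate to argmax (greedy) sampling.
    return temperature < SAMPLING_EPS

assert is_greedy(0.0) and not is_greedy(0.8)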
35 changes: 25 additions & 10 deletions serve/mlc_serve/engine/async_connector.py
@@ -18,6 +18,7 @@

ResultQueue = asyncio.Queue[RequestOutput]


class AsyncEngineConnector:
def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine = engine
@@ -32,7 +33,10 @@ async def start(self):
"""
Needs to be called in the thread with event loop
"""
LOG.info("Starting AsyncEngineConnector.", engine_wait_timeout=self.engine_wait_timeout)
LOG.info(
"Starting AsyncEngineConnector.",
engine_wait_timeout=self.engine_wait_timeout,
)
if self.engine_loop_task is not None:
return

@@ -78,7 +82,7 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]:
try:
queue = await self._add_request(request)
while True:
# TODO(jknight): Should make sure we are catching cancellations
# TODO(jknight): Should make sure we are catching cancellations
# correctly inside this _get_queue...(...) awaitable object too
output = await self._get_queue_item_until_stopped(queue)
if output.error is not None:
@@ -87,21 +91,32 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]:
if output.is_finished:
return
except asyncio.CancelledError:
LOG.info("AsyncEngineConnector.generate iterator cancelled.", request_id=request.request_id)
LOG.info(
"AsyncEngineConnector.generate iterator cancelled.",
request_id=request.request_id,
)
# Running this sync because `await` inside of cancellation events is problematic
self.engine.cancel(request.request_id)
LOG.info("AsyncEngineConnector.generate request sucessfully cancelled.", request_id=request.request_id)
LOG.info(
"AsyncEngineConnector.generate request sucessfully cancelled.",
request_id=request.request_id,
)
self.recent_cancelled_requests.appendleft(request.request_id)
# Always re-raise CancellationErrors unless you know what you're doing.
raise
finally:
LOG.info("AsyncEngineConnector.generate removing request from result queue.", request_id=request.request_id)
LOG.info(
"AsyncEngineConnector.generate removing request from result queue.",
request_id=request.request_id,
)
self.result_queues.pop(request.request_id, None)

async def _get_queue_item_until_stopped(self, queue: ResultQueue) -> RequestOutput:
try:
get_queue_task = asyncio.create_task(queue.get(), name="get_queue_task")
wait_shutdown_task = asyncio.create_task(self.shutdown_event.wait(), name="wait_shutdown_task")
wait_shutdown_task = asyncio.create_task(
self.shutdown_event.wait(), name="wait_shutdown_task"
)

await asyncio.wait(
(get_queue_task, wait_shutdown_task),
@@ -110,7 +125,9 @@

if wait_shutdown_task.done():
if self.engine_loop_exception is not None:
raise EngineException("raised while handling previous engine loop exception") from self.engine_loop_exception
raise EngineException(
"raised while handling previous engine loop exception"
) from self.engine_loop_exception
else:
raise EngineException("stopped with no exception")

@@ -147,8 +164,6 @@ async def _dispatch_result(self, result: InferenceStepResult) -> None:
elif request_id not in self.recent_cancelled_requests:
# if request is not in result_queues, and not cancelled recently,
# something goes wrong and we want to be aware of it.
LOG.warn(
f"Unknown request id when dispatching result: {request_id}"
)
LOG.warn(f"Unknown request id when dispatching result: {request_id}")

await asyncio.gather(*coroutines)
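
Beyond the reformatting, the interesting piece of async_connector.py is _get_queue_item_until_stopped, which races a queue read against the engine shutdown event so a stopped engine cannot strand waiting clients. Here is a condensed sketch of the same pattern, independent of the mlc_serve classes; error types and cleanup details in the real code may differ.

import asyncio

async def get_until_stopped(queue: asyncio.Queue, shutdown: asyncio.Event):
    # Race the queue read against the shutdown event; whichever task finishes
    # first wins, and both are cancelled on the way out so nothing leaks.
    get_task = asyncio.create_task(queue.get(), name="get_queue_task")
    stop_task = asyncio.create_task(shutdown.wait(), name="wait_shutdown_task")
    try:
        await asyncio.wait({get_task, stop_task}, return_when=asyncio.FIRST_COMPLETED)
        if stop_task.done():
            raise RuntimeError("engine stopped while waiting for output")
        return get_task.result()
    finally:
        get_task.cancel()
        stop_task.cancel()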
15 changes: 10 additions & 5 deletions serve/mlc_serve/engine/base.py
@@ -1,15 +1,16 @@
from __future__ import annotations
import structlog
import torch
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod

from typing import List, Callable, Any, Optional, Dict
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..openai_logprob_protocol import LogprobsContent
from ..model.base import ModelArtifactConfig

LOG = structlog.stdlib.get_logger(__name__)
RequestId = str
@@ -19,10 +20,8 @@
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]
top_token_ids: Optional[torch.Tensor]
top_logprobs: Optional[torch.Tensor]


# TODO(@sunggg): consider transition to something like Pydantic
@@ -201,6 +200,12 @@ class InferenceStepResult:


class InferenceEngine(ABC):
"""
Expose the model config to the high-level APIs.
"""

model_artifact_config: ModelArtifactConfig

@abstractmethod
def add(self, requests: list[Request]) -> None:
"""
(Diff truncated; the remaining changed files are not shown.)
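
Two things stand out in the visible part of base.py: RawLogprobsInfo now carries its top-k data as torch tensors rather than NumPy arrays, and InferenceEngine exposes model_artifact_config so the HTTP layer can read the model configuration. Below is a sketch of how such a logprob record is typically filled from one position's logits; the helper name and shapes are assumptions, not code from this PR.

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class RawLogprobsInfo:          # mirrors the dataclass in engine/base.py
    current_token_id: int
    current_logprob: float
    top_token_ids: Optional[torch.Tensor]
    top_logprobs: Optional[torch.Tensor]

def make_logprob_info(logits: torch.Tensor, token_id: int, top_k: int) -> RawLogprobsInfo:
    # logits: [vocab_size] for a single sequence position.
    logprobs = torch.log_softmax(logits, dim=-1)
    top_logprobs, top_token_ids = torch.topk(logprobs, top_k)
    return RawLogprobsInfo(
        current_token_id=token_id,
        current_logprob=float(logprobs[token_id]),
        top_token_ids=top_token_ids,
        top_logprobs=top_logprobs,
    )

info = make_logprob_info(torch.randn(32000), token_id=42, top_k=5)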