diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 1c558609c8..ae4f40a32b 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -30,6 +30,7 @@ SamplingParams, StoppingCriteria, ) +from ..model.base import ModelArtifactConfig from ..engine.async_connector import AsyncEngineConnector from .dependencies import get_async_engine_connector @@ -44,7 +45,9 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse router = APIRouter() -def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: +def _get_sampling_params( + request: ChatCompletionRequest, model_artifact_config: ModelArtifactConfig +) -> SamplingParams: sampling_params = SamplingParams( # These params came from vllm # TODO(amnalyshe): should they be put into mlc-llm batch serving ChatCompletionRequest? @@ -68,6 +71,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: if request.logprobs: sampling_params.top_logprobs = request.top_logprobs sampling_params.logprobs = request.logprobs + + sampling_params.vocab_size = model_artifact_config.vocab_size return sampling_params @@ -106,8 +111,10 @@ def random_uuid() -> str: request_id = f"cmpl-{random_uuid()}" model_name = request.model + + model_artifact_config = async_engine_connector.engine.model_artifact_config try: - sampling_params = _get_sampling_params(request) + sampling_params = _get_sampling_params(request, model_artifact_config) except ValueError as e: raise ValueError( """ @@ -130,7 +137,9 @@ def random_uuid() -> str: messages=request.messages, num_sequences=request.n, sampling_params=sampling_params, - stopping_criteria=StoppingCriteria(max_tokens=request.max_tokens, stop_sequences=stop_sequences), + stopping_criteria=StoppingCriteria( + max_tokens=request.max_tokens, stop_sequences=stop_sequences + ), debug_options=DebugOptions(ignore_eos=ignore_eos), ) if isinstance(request.messages, str): @@ -196,7 +205,9 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, - logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None + logprob_info=Logprobs(content=seq.logprob_info) + if seq.logprob_info != [] + else None, ) for seq in res.sequences ] @@ -217,7 +228,7 @@ async def collect_result_stream( finish_reasons = [None] * num_sequences num_prompt_tokens = 0 num_generated_tokens = [0 for _ in range(num_sequences)] - logprob_infos = [[] for _ in range(num_sequences)] # type: ignore + logprob_infos = [[] for _ in range(num_sequences)] # type: ignore async for res in result_generator: # TODO: verify that the request cancellation happens after this returns if res.error: @@ -241,7 +252,9 @@ async def collect_result_stream( finish_reasons[seq.index] = seq.finish_reason.value # type: ignore choices = [] - for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): + for index, (logprob_info_seq, chunks, finish_reason) in enumerate( + zip(logprob_infos, sequences, finish_reasons) + ): logprobs = None if logprob_info_seq != []: logprobs = Logprobs(content=logprob_info_seq) diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index b2fb08a079..7c448fcd76 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -17,6 +17,10 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, RawLogprobsInfo, - RawLogprobsInfos, ) -from .sampling_params import SamplingParams, 
SamplingType, LOGPROB_TOP_K_MAX +from .sampling_params import ( + SamplingParams, + SamplingType, + LOGPROB_TOP_K_MAX, + _SAMPLING_EPS as SAMPLING_EPS, +) diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py index 8eb2d85168..907e840314 100644 --- a/serve/mlc_serve/engine/async_connector.py +++ b/serve/mlc_serve/engine/async_connector.py @@ -18,6 +18,7 @@ ResultQueue = asyncio.Queue[RequestOutput] + class AsyncEngineConnector: def __init__(self, engine: InferenceEngine, engine_wait_timeout=1): self.engine = engine @@ -32,7 +33,10 @@ async def start(self): """ Needs to be called in the thread with event loop """ - LOG.info("Starting AsyncEngineConnector.", engine_wait_timeout=self.engine_wait_timeout) + LOG.info( + "Starting AsyncEngineConnector.", + engine_wait_timeout=self.engine_wait_timeout, + ) if self.engine_loop_task is not None: return @@ -78,7 +82,7 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]: try: queue = await self._add_request(request) while True: - # TODO(jknight): Should make sure we are catching cancellations + # TODO(jknight): Should make sure we are catching cancellations # correctly inside this _get_queue...(...) awaitable object too output = await self._get_queue_item_until_stopped(queue) if output.error is not None: @@ -87,21 +91,32 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]: if output.is_finished: return except asyncio.CancelledError: - LOG.info("AsyncEngineConnector.generate iterator cancelled.", request_id=request.request_id) + LOG.info( + "AsyncEngineConnector.generate iterator cancelled.", + request_id=request.request_id, + ) # Running this sync because `await` inside of cancellation events is problematic self.engine.cancel(request.request_id) - LOG.info("AsyncEngineConnector.generate request sucessfully cancelled.", request_id=request.request_id) + LOG.info( + "AsyncEngineConnector.generate request sucessfully cancelled.", + request_id=request.request_id, + ) self.recent_cancelled_requests.appendleft(request.request_id) # Always re-raise CancellationErrors unless you know what you're doing. 
raise finally: - LOG.info("AsyncEngineConnector.generate removing request from result queue.", request_id=request.request_id) + LOG.info( + "AsyncEngineConnector.generate removing request from result queue.", + request_id=request.request_id, + ) self.result_queues.pop(request.request_id, None) async def _get_queue_item_until_stopped(self, queue: ResultQueue) -> RequestOutput: try: get_queue_task = asyncio.create_task(queue.get(), name="get_queue_task") - wait_shutdown_task = asyncio.create_task(self.shutdown_event.wait(), name="wait_shutdown_task") + wait_shutdown_task = asyncio.create_task( + self.shutdown_event.wait(), name="wait_shutdown_task" + ) await asyncio.wait( (get_queue_task, wait_shutdown_task), @@ -110,7 +125,9 @@ async def _get_queue_item_until_stopped(self, queue: ResultQueue) -> RequestOutp if wait_shutdown_task.done(): if self.engine_loop_exception is not None: - raise EngineException("raised while handling previous engine loop exception") from self.engine_loop_exception + raise EngineException( + "raised while handling previous engine loop exception" + ) from self.engine_loop_exception else: raise EngineException("stopped with no exception") @@ -147,8 +164,6 @@ async def _dispatch_result(self, result: InferenceStepResult) -> None: elif request_id not in self.recent_cancelled_requests: # if request is not in result_queues, and not cancelled recently, # something goes wrong and we want to be aware of it. - LOG.warn( - f"Unknown request id when dispatching result: {request_id}" - ) + LOG.warn(f"Unknown request id when dispatching result: {request_id}") await asyncio.gather(*coroutines) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index a689e43888..f60a894e3b 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -1,15 +1,16 @@ from __future__ import annotations import structlog +import torch from dataclasses import dataclass, field from enum import Enum from abc import ABC, abstractmethod from typing import List, Callable, Any, Optional, Dict import inspect -import numpy as np from .sampling_params import SamplingParams, SamplingType from ..openai_logprob_protocol import LogprobsContent +from ..model.base import ModelArtifactConfig LOG = structlog.stdlib.get_logger(__name__) RequestId = str @@ -19,10 +20,8 @@ class RawLogprobsInfo: current_token_id: int current_logprob: float - top_token_ids: Optional[np.ndarray] - top_logprobs: Optional[np.ndarray] - -RawLogprobsInfos = List[Optional[RawLogprobsInfo]] + top_token_ids: Optional[torch.Tensor] + top_logprobs: Optional[torch.Tensor] # TODO(@sunggg): consider transition to something like Pydantic @@ -201,6 +200,12 @@ class InferenceStepResult: class InferenceEngine(ABC): + """ + Expose the model config to the high-level APIs. 
+ """ + + model_artifact_config: ModelArtifactConfig + @abstractmethod def add(self, requests: list[Request]) -> None: """ diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 4a02ce2f60..96e17eeefc 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -11,13 +11,12 @@ from .base import ( GenerationSequence, - RawLogprobsInfo, - RawLogprobsInfos, Request, RequestId, RequestState, SequenceId, StoppingCriteria, + RawLogprobsInfo, ) from .model_module import ( DecodeRequest, @@ -82,14 +81,19 @@ def detokenize_incrementally( prompt_tokens: list[int], generation_sequence: GenerationSequence, tokenizer: TokenizerP, - skip_special_tokens=False, + new_token_id: Optional[int] = None, + skip_special_tokens: bool = False, ) -> str: - new_token_id = generation_sequence.generated_token_ids[-1] - + # tokenizer.decode() is similar to doing convert_tokens_to_string(convert_ids_to_tokens(token_ids)) + # in this function, we separate these two steps. + is_logprob = new_token_id is not None + new_token_id = ( + generation_sequence.generated_token_ids[-1] if not is_logprob else new_token_id + ) # This is the first iteration for this sequence if generation_sequence.prev_tokens is None: # TODO(masahi): Figure out a way to remove this concat - new_tokens = tokenizer.convert_ids_to_tokens( + new_tokens: List[str] = tokenizer.convert_ids_to_tokens( # type: ignore prompt_tokens + generation_sequence.generated_token_ids ) output_tokens = new_tokens @@ -105,7 +109,9 @@ def detokenize_incrementally( prefix_end_offset = max(len(output_tokens) - 1, 0) else: # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens([new_token_id]) + new_tokens: List[str] = tokenizer.convert_ids_to_tokens( # type: ignore + [new_token_id] + ) output_tokens = generation_sequence.prev_tokens + new_tokens prefix_begin_offset = generation_sequence.prefix_begin_offset @@ -131,62 +137,19 @@ def detokenize_incrementally( new_prefix_end_offset = prefix_end_offset delta = "" - generation_sequence.prefix_begin_offset = new_prefix_begin_offset - generation_sequence.prefix_end_offset = new_prefix_end_offset - if generation_sequence.prev_tokens is None: - generation_sequence.prev_tokens = new_tokens - else: - generation_sequence.prev_tokens.extend(new_tokens) + # Update the status + # If this is for logprob, we do not update the status + if not is_logprob: + generation_sequence.prefix_begin_offset = new_prefix_begin_offset + generation_sequence.prefix_end_offset = new_prefix_end_offset + if generation_sequence.prev_tokens is None: + generation_sequence.prev_tokens = new_tokens + else: + generation_sequence.prev_tokens.extend(new_tokens) return delta -def logprob_detokenize( - tokenizer: TokenizerP, - logprob_info: Optional[RawLogprobsInfo], -) -> Optional[LogprobsContent]: - """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" - if logprob_info is None: - return None - - top_logprobs: List[TopLogprobs] = [] - if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None: - top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs)) - for top_token_id, top_logprob in top_tokens: - top_logprobs.append( - TopLogprobs( - token=tokenizer.decode(top_token_id), - logprob=float(top_logprob), - # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object - bytes=None, - ) - ) - - logprobs_content = 
LogprobsContent( - token=tokenizer.decode([logprob_info.current_token_id]), - logprob=logprob_info.current_logprob, - # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object - bytes=None, - top_logprobs=top_logprobs, - ) - - return logprobs_content - - -def logprobs_detokenize( - tokenizer: TokenizerP, - logprob_info: Optional[RawLogprobsInfos], -) -> List[Optional[LogprobsContent]]: - if logprob_info is None: - return [] - - res: List[Optional[LogprobsContent]] = [] - for info in logprob_info: - res.append(logprob_detokenize(tokenizer, info)) - - return res - - def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended): if stopping_criteria.stop_sequences: for t in stopping_criteria.stop_sequences: @@ -206,13 +169,58 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended): return output_text, delta, is_ended -def update_sequence( +def prepare_logprob( + logprob_info: Optional[List[Optional[RawLogprobsInfo]]], + delta: str, + gen_seq: GenerationSequence, + prompt_token_ids: List[int], + tokenizer: TokenizerP, +) -> List[Optional[LogprobsContent]]: + if logprob_info is None: + return [] + + outputs = [] + for info in logprob_info: + assert info is not None + assert info.top_token_ids is not None + + top_logprobs: List[TopLogprobs] = [] + if info.top_logprobs is not None: + assert info.top_logprobs is not None + token_ids = info.top_token_ids.cpu().numpy() + logprobs = info.top_logprobs.cpu().numpy() + + for top_token_id, top_logprob in zip(token_ids, logprobs): + top_logprobs.append( + TopLogprobs( + token=detokenize_incrementally( + prompt_token_ids, gen_seq, tokenizer, top_token_id + ), + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + ) + ) + + logprobs_content = LogprobsContent( + token=delta, + logprob=info.current_logprob, + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + top_logprobs=top_logprobs, + ) + outputs.append(logprobs_content) + return outputs + + +def prepare_output( gen_seq: GenerationSequence, new_token_ids: list[int], prompt_token_ids: list[int], + logprob_info, tokenizer: TokenizerP, stopping_criteria: StoppingCriteria, -) -> str: +) -> Tuple[str, List[Optional[LogprobsContent]]]: gen_seq.next_start_position = len(prompt_token_ids) + len( gen_seq.generated_token_ids ) @@ -220,11 +228,15 @@ def update_sequence( delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer) gen_seq.output_text += delta + out_logprob_info: List[Optional[LogprobsContent]] = prepare_logprob( + logprob_info, delta, gen_seq, prompt_token_ids, tokenizer + ) + gen_seq.output_text, delta, gen_seq.is_finished = check_stopping_sequences( stopping_criteria, gen_seq.output_text, delta, gen_seq.is_finished ) - return delta + return delta, out_logprob_info def get_requests_to_process( diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index a5b86d69b9..0e091e0c4a 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -7,7 +7,7 @@ from .base import ( ChatMessage, MLCServeEngineConfig, - RawLogprobsInfos, + RawLogprobsInfo, RequestId, RequestState, SequenceId, @@ -19,6 +19,7 @@ @dataclass class PrefillRequest: request_id: RequestId + # `token_ids` contains prompt token ids token_ids: List[int] # Number of sequences to generate num_sequence: int @@ -67,7 
+68,6 @@ class EvalMultiQueryRequest: RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest] -RequestsType = Sequence[RequestType] @dataclass @@ -81,7 +81,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] - logprob_info: Optional[RawLogprobsInfos] + logprob_info: Optional[List[Optional[RawLogprobsInfo]]] class KVCache(Protocol): @@ -155,7 +155,7 @@ class TextGenerator(Protocol): def generate( self, - requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], + requests: Sequence[RequestType], kv_cache, ) -> List[TextGenerationResult]: """ diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index 961b2b744a..6344b7ee15 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -3,7 +3,6 @@ based on https://github.com/vllm-project/vllm/blob/ac5cf86aa6aebbf9e42df51f7e377fbee85bc703/vllm/sampling_params.py """ -from collections import defaultdict from dataclasses import dataclass from enum import IntEnum from functools import cached_property @@ -46,6 +45,8 @@ class SamplingParams: to -1 to consider all tokens. logit_bias: The bias applied on the logit before sampling. Must be in [-100, 100]. + logit_bias_index: Internal data container that stores indices of `logit_bias`. + logit_bias_value: Internal data container that stores values of `logit_bias`. logprobs: Optional[bool] Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. @@ -53,6 +54,8 @@ class SamplingParams: the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. + vocab_size: Not a part of the sampling params, but needed for the argument validation. + Remove this when we have a better solution. """ presence_penalty: float = 0.0 @@ -62,21 +65,27 @@ class SamplingParams: top_p: float = 1.0 top_k: int = -1 logit_bias: Optional[Dict[int, float]] = None - appeared_tokens_freq: Dict[int, int] = None logit_bias_index: list[int] = None logit_bias_value: list[float] = None logprobs: bool = False top_logprobs: int = 0 + # TODO(@team): This info comes from the model config. + # Currently, it is unclear what is the best way to fetch this info and + # check in `_verify_args` without this field. Follow-up when we have a better idea. + vocab_size = 32000 def __post_init__(self): - self.appeared_tokens_freq = {} if self.logit_bias: self.logit_bias_index = list(self.logit_bias.keys()) self.logit_bias_value = list(self.logit_bias.values()) self._verify_args() if self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. + self.top_p = 1.0 + self.top_k = -1 self._verify_greedy_sampling() + if not self.logprobs: + self.top_logprobs = 0 def _verify_args(self) -> None: if not -2.0 <= self.presence_penalty <= 2.0: @@ -94,6 +103,10 @@ def _verify_args(self) -> None: ) if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + + if not isinstance(self.top_k, int): + raise ValueError(f"top_k must be integer.") + if self.top_k < -1 or self.top_k == 0: raise ValueError( f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." 
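Not part of the patch: a compact restatement of the argument handling the sampling_params.py hunks add (the greedy-sampling normalization in __post_init__ plus the new _verify_args checks). Names, the epsilon, and the top-k bound are placeholders here; the real values come from _SAMPLING_EPS and LOGPROB_TOP_K_MAX in this module, and vocab_size defaults to 32000 until the API handler overwrites it from ModelArtifactConfig.

# Illustrative sketch only; mirrors the checks added to SamplingParams in this file.
def normalize_and_check(temperature, top_p, top_k, repetition_penalty,
                        logprobs, top_logprobs, logit_bias=None,
                        vocab_size=32000, eps=1e-5, top_k_max=5):
    if temperature < eps:      # zero temperature -> greedy: neutralize top-p/top-k
        top_p, top_k = 1.0, -1
    if not logprobs:           # top_logprobs is only meaningful with logprobs=True
        top_logprobs = 0
    if not isinstance(top_k, int) or top_k == 0 or top_k < -1:
        raise ValueError("top_k must be -1 (disable) or at least 1")
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in (0, 1]")
    if repetition_penalty <= 0:
        raise ValueError("repetition penalty must be a positive float")
    for token, bias in (logit_bias or {}).items():
        if not -100 <= bias <= 100 or not 1 <= token <= vocab_size:
            raise ValueError("logit bias out of range or token id not in [1, vocab_size]")
    if logprobs and not 0 <= top_logprobs <= top_k_max:
        raise ValueError("top_logprobs out of range")
    return top_p, top_k, top_logprobs

The 1-based token-id bound in the logit-bias check lines up with the `logit_bias_indices -= 1` conversion performed later in model/sampler.py.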
@@ -104,8 +117,16 @@ def _verify_args(self) -> None: raise ValueError( f"logit bias must be in [-100, 100], got {bias} for token {token}." ) + if not 1 <= token <= self.vocab_size: + raise ValueError(f"token id must be in [1, vocab_size]") + + if self.repetition_penalty <= 0: + raise ValueError( + f"repetition penalty should be a positive float value, got {self.repetition_penalty}." + ) + if self.logprobs: - if (self.top_logprobs < 0 or self.top_logprobs > LOGPROB_TOP_K_MAX): + if self.top_logprobs < 0 or self.top_logprobs > LOGPROB_TOP_K_MAX: raise ValueError( f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." ) diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index c8354e4c5b..fb7e3f6c03 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -21,12 +21,9 @@ ScopedInferenceEngine, SequenceOutput, ) -from .engine_common import ( - get_new_request_state, - update_sequence, - logprobs_detokenize -) +from .engine_common import get_new_request_state, prepare_output from .model_module import ModelModule, TokenizerModule +from ..model.base import get_model_artifact_config from .staging_engine_worker import ( AddRequestsCommand, CancelRequestCommand, @@ -61,6 +58,11 @@ def __init__( self.requests_lock = Lock() self.requests = dict[RequestId, RequestState]() + # TODO(@team): This is a temporary solution to expose model config to higher API layer. + # Follow-up with the proper solution + self.model_artifact_config = get_model_artifact_config( + model_module_loader_kwargs["model_artifact_path"] + ) self.tokenizer = tokenizer_module.tokenizer self.conversation_template = tokenizer_module.conversation_template @@ -209,6 +211,10 @@ def step(self) -> InferenceStepResult: with structlog.contextvars.bound_contextvars(**state.contextvars): if seq_output.error is not None: + LOG.exception( + "An error occurred during generating sequence outputs.", + exc=seq_output.error, + ) outputs.append( RequestOutput( request_id, @@ -222,17 +228,19 @@ def step(self) -> InferenceStepResult: gen_seq = state.generation_sequences[seq_output.id.sequence_index] new_token_ids = seq_output.new_tokens - + LOG.debug(f"New token ids: {new_token_ids}") if new_token_ids: - delta = update_sequence( + delta, logprob_info = prepare_output( gen_seq, new_token_ids, state.prompt_token_ids, + seq_output.logprob_info, self.tokenizer, state.stopping_criteria, ) else: delta = None + logprob_info = [] if not state.is_prefilled: # Successfully completed a prefill request @@ -252,7 +260,7 @@ def step(self) -> InferenceStepResult: delta, finish_reason, num_generated_tokens=len(gen_seq.generated_token_ids), - logprob_info=logprobs_detokenize(self.tokenizer, seq_output.logprob_info), + logprob_info=logprob_info, ) seq_outputs[request_id].append(output) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index da4731a96f..c73c6160c2 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -4,7 +4,7 @@ import time import multiprocessing import multiprocessing.synchronize -from dataclasses import dataclass, field +from dataclasses import dataclass from threading import Thread, Lock from typing import Callable, Optional, Union, Any, Dict, List @@ -12,12 +12,12 @@ from .base import ( FinishReason, - RawLogprobsInfos, RequestId, RequestState, ValidationError, SequenceId, GenerationSequence, + RawLogprobsInfo, 
) from .metrics import PrometheusMetrics @@ -66,7 +66,7 @@ class SequenceGenerationOutput: new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None - logprob_info: Optional[RawLogprobsInfos] = None + logprob_info: Optional[List[Optional[RawLogprobsInfo]]] = None @dataclass @@ -285,7 +285,6 @@ def step(self) -> GenerationLoopWorkerOutput: ): gen_seq.is_finished = True finish_reason = FinishReason.Length - outputs.append( SequenceGenerationOutput( id=res.sequence_id, diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index c400ec6b4a..0baa8b50bf 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -20,9 +20,8 @@ should_stop_by_length, get_new_request_state, get_requests_to_process, - update_sequence, + prepare_output, EngineBase, - logprobs_detokenize ) from .model_module import ( ModelModule, @@ -191,10 +190,11 @@ def step(self) -> InferenceStepResult: gen_seq.is_finished = True break - delta = update_sequence( + delta, logprob_info = prepare_output( gen_seq, new_token_ids, state.prompt_token_ids, + res.logprob_info, self.tokenizer, state.stopping_criteria, ) @@ -223,7 +223,7 @@ def step(self) -> InferenceStepResult: delta, num_generated_tokens=len(gen_seq.generated_token_ids), finish_reason=finish_reason, - logprob_info=logprobs_detokenize(self.tokenizer, res.logprob_info), + logprob_info=logprob_info, ) ) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 22ebec7ebb..1dda0cee94 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union, Sequence import structlog import numpy as np @@ -7,22 +7,18 @@ from .paged_cache_manager import CacheManager from ..engine import ( - SamplingType, - SamplingParams, get_prompt_sequence_id, - LOGPROB_TOP_K_MAX, PROMPT_SEQEUNCE_INDEX, - RawLogprobsInfo, - RawLogprobsInfos, SequenceId, + RawLogprobsInfo, ) from ..engine.model_module import ( PrefillRequest, EvalMultiQueryRequest, RequestType, - RequestsType, TextGenerationResult, ) +from .sampler import sample, adjust_logits, SamplingState, SamplingOutput LOG = structlog.stdlib.get_logger(__name__) @@ -50,275 +46,14 @@ def get_num_cache_blocks( ) -def get_logprob_infos( - i: int, - logprob_infos: Optional[RawLogprobsInfos], -) -> Optional[RawLogprobsInfos]: - if logprob_infos is None or logprob_infos[i] is None: - return None - return [logprob_infos[i]] - - -def get_raw_logprob_info( - logits, - token_id, - top_logprobs_num, -) -> RawLogprobsInfo: - logprobs = torch.log_softmax(logits, dim=-1) - res_logprob = logprobs[token_id] - - if top_logprobs_num == 0: - top_logprobs = None - top_tokens = None - else: - assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" - top_logprobs, top_tokens = torch.topk( - logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True - ) - top_tokens = top_tokens.cpu().numpy() - top_logprobs = top_logprobs.cpu().numpy() - - # Set to raw logprob info - return RawLogprobsInfo( - current_token_id=token_id, - current_logprob=res_logprob, - top_token_ids=top_tokens, - top_logprobs=top_logprobs, - ) - - -def get_logprob_indices( - sampling_params: List[SamplingParams], - num_seq: int, -) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]: - lgp_inds_greedy: List[Tuple[int, int, int]] = [] - lgp_inds_random: 
List[Tuple[int, int, int]] = [] - - g_ind = 0 - r_ind = 0 - for i in range(num_seq): - sampling_param = sampling_params[i] - if sampling_param.sampling_type == SamplingType.RANDOM: - if sampling_param.logprobs: - lgp_inds_random.append((i, r_ind, sampling_param.top_logprobs)) - r_ind = r_ind + 1 - else: - if sampling_param.logprobs: - lgp_inds_greedy.append((i, g_ind, sampling_param.top_logprobs)) - g_ind = g_ind + 1 - - return lgp_inds_greedy, lgp_inds_random - - -def get_raw_logprob_infos( - logprob_infos: RawLogprobsInfos, - indices: List[Tuple[int, int, int]], - logits: torch.Tensor, - token_ids: torch.Tensor, -) -> RawLogprobsInfos: - for i, ind, top_logprobs in indices: - logprob_infos[i] = get_raw_logprob_info( - logits[ind], - token_ids[ind], - top_logprobs, - ) - - return logprob_infos - - -def check_logprob_infos( - logprob_infos: RawLogprobsInfos, -) -> Optional[RawLogprobsInfos]: - check = False - for info in logprob_infos: - if info is not None: - check = True - break - if check: - return logprob_infos - return None - - -def _apply_top_p_top_k(logits, top_ps, top_ks): - p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) - k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) - logits_sort, logits_idx = logits.sort(dim=-1, descending=True) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) - logits_sort[top_p_mask] = -float("inf") - - # Apply top-k. - # Create a mask for the top-k elements. - top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) - top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) - top_k_mask = top_k_mask >= k.unsqueeze(dim=1) - logits_sort[top_k_mask] = -float("inf") - - # Re-sort the probabilities. 
- logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1)) - return logits - - -def sample( - logits: Union[tvm.nd.NDArray, torch.Tensor], - sampling_params: List[SamplingParams], - vocab_size: int, - check_safety=False, -) -> Optional[Tuple[np.ndarray, Optional[RawLogprobsInfos]]]: - def _is_safe_to_sample(prob_like): - return ( - torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) - == 0 - ) - - torch.cuda.nvtx.range_push(f"sample {logits.shape}") - logits = torch.from_dlpack(logits) - num_seq = len(sampling_params) - - mask_random_cpu = torch.tensor( - [p.sampling_type == SamplingType.RANDOM for p in sampling_params], - dtype=torch.bool, - ) - mask_greedy_cpu = torch.logical_not(mask_random_cpu) - if logits.device == torch.device("cpu"): - mask_random_dvc = mask_random_cpu - mask_greedy_dvc = mask_greedy_cpu - else: # gpu - mask_random_dvc = mask_random_cpu.to(logits.device) - mask_greedy_dvc = mask_greedy_cpu.to(logits.device) - - logits_greedy = logits[mask_greedy_dvc] - - logprob_infos: RawLogprobsInfos = [None] * num_seq - lgp_inds_greedy, lgp_inds_random = get_logprob_indices( - sampling_params, - num_seq, - ) - - if logits_greedy.shape[0] > 0: - res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() - - logprob_infos = get_raw_logprob_infos( - logprob_infos, - lgp_inds_greedy, - logits_greedy, - res_greedy, - ) - - # Case when there's only greedy sampling - if logits_greedy.shape[0] == num_seq: - torch.cuda.nvtx.range_pop() - return res_greedy, check_logprob_infos(logprob_infos) - - temperatures = [] - top_ps = [] - top_ks = [] - divide_by_temperature = False - do_top_p = False - do_top_k = False - - for i in range(num_seq): - param = sampling_params[i] - freq = param.appeared_tokens_freq - - if param.sampling_type == SamplingType.RANDOM: - temperatures.append(param.temperature) - top_ps.append(param.top_p) - top_ks.append(param.top_k if param.top_k != -1 else vocab_size) - - divide_by_temperature |= temperatures[-1] != 1.0 - do_top_p |= top_ps[-1] < 1.0 - do_top_k |= top_ks[-1] != vocab_size - - # TODO(vvchernov): need to strictly define order of using penalties and logit bias or - # prohibit simultaneous using of them. 
At the latter case it can be LogitProcessor - if ( - not param.presence_penalty == 0.0 or not param.frequency_penalty == 0 - ) and bool(freq): - index = torch.from_numpy(np.array(list(freq.keys()))).to( - device=logits.device - ) - src = ( - torch.from_numpy(np.array(list(freq.values()))) - .type_as(logits) - .to(device=logits.device) - ) - logits[i][index] -= ( - src * param.frequency_penalty + param.presence_penalty - ) - - if not param.repetition_penalty == 1.0 and bool(freq): - index = torch.from_numpy(np.array(list(freq.keys()))).to( - device=logits.device - ) - logits[i][index] /= param.repetition_penalty - - if param.logit_bias: - logits[i][param.logit_bias_index] += ( - torch.Tensor(param.logit_bias_value) - .type_as(logits) - .to(device=logits.device) - ) - - logits_random = logits[mask_random_dvc] - - if divide_by_temperature: - t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device) - logits_random.div_(t.unsqueeze(dim=1)) - - if do_top_p or do_top_k: - logits_random = _apply_top_p_top_k(logits_random, top_ps, top_ks) - - probs = torch.softmax(logits_random, dim=-1) - - if check_safety and not _is_safe_to_sample(probs): - torch.cuda.nvtx.range_pop() - return None - - res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy() - - logprob_infos = get_raw_logprob_infos( - logprob_infos, - lgp_inds_random, - logits_random, - res_random, - ) - - # Case when there's only random sampling - if logits_random.shape[0] == num_seq: - torch.cuda.nvtx.range_pop() - return res_random, check_logprob_infos(logprob_infos) - - res = np.empty((num_seq,), dtype=np.int32) - res[mask_random_cpu] = res_random - - if logits_greedy.shape[0] > 0: - res[mask_greedy_cpu] = res_greedy - - torch.cuda.nvtx.range_pop() - return res, check_logprob_infos(logprob_infos) - - -def update_tokens_frequency( - request: RequestType, - new_token: int -): - if not new_token in request.sampling_params.appeared_tokens_freq: - request.sampling_params.appeared_tokens_freq[new_token] = 0 - request.sampling_params.appeared_tokens_freq[new_token] += 1 - - -def append_text_gen_res( - outputs: List[TextGenerationResult], +def prepare_textgen_result( request: RequestType, new_token: List[int], sequence_id: SequenceId, - logprob_info: Optional[RawLogprobsInfos], - err_msg: Optional[str]=None, + logprob_info: Optional[List[Optional[RawLogprobsInfo]]], + err_msg: Optional[str] = None, ) -> List[TextGenerationResult]: + outputs = [] if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(request, PrefillRequest) for seq_id in range(request.num_sequence): # type: ignore @@ -345,66 +80,100 @@ def append_text_gen_res( def sample_from_logits( logits: Union[tvm.nd.NDArray, torch.Tensor], sequence_ids: List[SequenceId], - requests: RequestsType, - vocab_size, + requests: Sequence[RequestType], + sampling_metadata: SamplingState, + vocab_size: int, + copy_stream: torch.cuda.Stream, + torch_dtype: torch.dtype, + torch_dev: str, + past_decode_tokens: List[List[int]], ) -> List[TextGenerationResult]: - assert logits.shape[0] == len(requests) - - sampling_params = [req.sampling_params for req in requests] + batch_size = logits.shape[0] + assert batch_size == len(requests) + # Convert to torch tensors if logits are in tvm ndarray + if isinstance(logits, tvm.nd.NDArray): + logits = torch.from_dlpack(logits) + + # synchronization point for sampling tensors + # wait until all the tensors are loaded on GPU + torch.cuda.current_stream().wait_stream(copy_stream) + logits = adjust_logits(logits, 
sampling_metadata, vocab_size) outputs: List[TextGenerationResult] = [] try: - next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size) - assert next_tokens is not None - for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): - update_tokens_frequency(requests[i], new_token) - outputs = append_text_gen_res( - outputs, - requests[i], - [new_token], - sequence_id, - get_logprob_infos(i, logprob_infos), - ) + sampling_output: Optional[SamplingOutput] = sample( + logits, + sampling_metadata, + ) - return outputs + for i, (new_token, logprob_info) in enumerate( + zip(sampling_output.next_tokens, sampling_output.logprob_infos) + ): + sequence_id = sequence_ids[i] + request = requests[i] + outputs.extend( + prepare_textgen_result( + request, + [new_token], + sequence_id, + [logprob_info] if logprob_info else None, + ) + ) except RuntimeError: # Fallback to per-token sampling in case some logits values are corrupted. err_msg = ( "Error from sampling: probability tensor contains either `inf`, `nan`" " or element < 0" ) - - for i, (sequence_id, logits_per_token, sampling_param) in enumerate( - zip(sequence_ids, torch.from_dlpack(logits), sampling_params) - ): - maybe_new_token, logprob_infos = sample( + logits = torch.from_dlpack(logits) + for i in range(batch_size): + sequence_id = sequence_ids[i] + logits_per_token = logits[i] + sampling_param = sampling_metadata.sampling_params[i] + past_decode_tokens_per_request = past_decode_tokens[i] + # NOTE: Rerun the preparation for simplicity. + # Assume this code path is taken rarely and the recomputation overhead is + # marginal. + with torch.cuda.stream(copy_stream): + new_sampling_metadata = SamplingState.from_sampling_params( + [sampling_param], + [past_decode_tokens_per_request], + torch_dtype, + torch_dev, + vocab_size, + ) + torch.cuda.current_stream().wait_stream(copy_stream) + maybe_sampling_output: Optional[SamplingOutput] = sample( torch.unsqueeze(logits_per_token, 0), - [sampling_param], - vocab_size, + new_sampling_metadata, check_safety=True, ) - if maybe_new_token is not None: - new_token = maybe_new_token[0] - update_tokens_frequency(requests[i], new_token) - outputs = append_text_gen_res( - outputs, - requests[i], - [new_token], - sequence_id, - get_logprob_infos(0, logprob_infos), + new_token = maybe_sampling_output.next_tokens[0] + logprob_info = maybe_sampling_output.logprob_infos[0] + # Valid sample + request = requests[i] + if maybe_sampling_output is not None: + outputs.extend( + prepare_textgen_result( + request, + [new_token], + sequence_id, + [logprob_info] if logprob_info else None, + ) ) else: - outputs = append_text_gen_res( - outputs, - requests[i], - [], # new_token - sequence_id, - get_logprob_infos(0, logprob_infos), - err_msg, + outputs.extend( + prepare_textgen_result( + request, + [], + sequence_id, + None, + err_msg, + ) ) - return outputs + return outputs def prepare_inputs( diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py new file mode 100644 index 0000000000..0562121b8f --- /dev/null +++ b/serve/mlc_serve/model/sampler.py @@ -0,0 +1,511 @@ +import torch +import numpy as np +import structlog +from dataclasses import dataclass +from typing import List, Optional, Tuple +from ..engine import ( + SamplingParams, + SamplingType, + SAMPLING_EPS, + LOGPROB_TOP_K_MAX, + RawLogprobsInfo, +) + +LOG = structlog.stdlib.get_logger(__name__) + + +def _apply_top_p_top_k( + logits: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor +): + # TODO(@team): 
Check the ordering. We currently apply top-p -> top-k. + logits_sort, logits_idx = logits.sort(dim=-1, descending=True) + + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) + logits_sort[top_p_mask] = -float("inf") + + # Apply top-k. + # Create a mask for the top-k elements. + top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) + top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) + top_k_mask = top_k_mask >= top_ks.unsqueeze(dim=1) + logits_sort[top_k_mask] = -float("inf") + + # Re-sort the probabilities. + logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1)) + return logits + + +@dataclass +class SamplingTensors: + """ + Sampling states prepared for sampling computations (`adjust_logits()` and `sample()`). + We keep mask tensors in CPU since putting them on GPU showed small performance regression. + Args: + mask_random: torch.Tensor + Mask for requests with random sampling. + shape: (batch_size, ) + mask_greedy: torch.Tensor + Mask for requests with greedy sampling. + shape: (batch_size, ) + mask_top_logprob: torch.Tensor + Mask for requests with top_logprob. + shape: (LOGPROB_TOP_K_MAX) + 1, batch_size,) + temperatures: torch.Tensor + Tensor for temperature values + shape: (batch_size, ) + top_ps: torch.Tensor + Tensor for top-p values + shape: (batch_size, ) + top_ks: torch.Tensor + Tensor for top-k values + shape: (batch_size, ) + frequency_penalties: torch.Tensor + Tensor for frequency penalty values + shape: (batch_size, ) + presence_penalties: torch.Tensor + Tensor for presence penalty values + shape: (batch_size, ) + repetition_penalties: torch.Tensor + Tensor for repetition penalty values + shape: (batch_size, ) + logit_bias_indices: torch.Tensor + Tensor for indices of logit bias + shape: (num_logit_bias_pairs, ) + logit_bias_values: torch.Tensor + Tensor for values of logit bias + shape: (num_logit_bias_pairs, ) + past_output_tokens: torch.Tensor + Tensor for generated tokens + shape: (batch_size, max_num_gen_tokens,) + """ + + mask_random: torch.Tensor + mask_greedy: torch.Tensor + mask_top_logprob: torch.Tensor + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + frequency_penalties: torch.Tensor + presence_penalties: torch.Tensor + repetition_penalties: torch.Tensor + logit_bias_indices: torch.Tensor + logit_bias_values: torch.Tensor + past_output_tokens: torch.Tensor + + @classmethod + def from_lists( + cls, + dtype, + dev, + list_mask_random: List[bool], + list_mask_top_logprob: List[List[bool]], + list_temperatures: List[float], + list_top_ps: List[float], + list_top_ks: List[int], + list_frequency_penalties: List[float], + list_presence_penalties: List[float], + list_repetition_penalties: List[float], + list_logit_bias_indices: List[List[int]], + list_logit_bias_values: List[List[float]], + list_past_output_tokens: List[List[int]], + ): + # NOTE: Keep `mask_random` and `mask_greedy` tensors in CPU. + # Moving them to gpu showed a small performance regression. 
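The CPU-side staging noted in the comment above pairs with the pinned-memory, non_blocking copies built just below and with the dedicated `_copy_stream` that tvm_model.py passes into `sample_from_logits`. A minimal sketch of that overlap pattern, illustrative only and assuming a CUDA device (tensor names are placeholders, not part of the diff):

import torch

copy_stream = torch.cuda.Stream()

def stage_params(values, dev="cuda", dtype=torch.float32):
    # Pinned host memory lets the host-to-device copy run asynchronously on the side stream.
    host = torch.tensor(values, dtype=dtype, device="cpu", pin_memory=True)
    with torch.cuda.stream(copy_stream):
        return host.to(device=dev, non_blocking=True)

temps = stage_params([0.7, 1.0, 0.0])
# ... the model forward pass runs on the default stream here ...
torch.cuda.current_stream().wait_stream(copy_stream)  # sync before sampling reads temps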
+ mask_random = torch.tensor( + list_mask_random, + dtype=torch.bool, + device="cpu", + ) + mask_greedy = torch.logical_not( + mask_random, + ) + # `mask_top_logprob` will be on cpu + mask_top_logprob = torch.from_numpy(list_mask_top_logprob) + temp = torch.tensor( + list_temperatures, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + top_ps = torch.tensor( + list_top_ps, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + top_ks = torch.tensor( + list_top_ks, + dtype=torch.int, + device="cpu", + pin_memory=True, + ) + frequency_penalties = torch.tensor( + list_frequency_penalties, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + presence_penalties = torch.tensor( + list_presence_penalties, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + repetition_penalties = torch.tensor( + list_repetition_penalties, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + logit_bias_indices = torch.tensor( + list_logit_bias_indices, + dtype=torch.long, + device="cpu", + pin_memory=True, + ) + # Convert 1-based index to 0-based + logit_bias_indices -= 1 + logit_bias_values = torch.tensor( + list_logit_bias_values, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + past_output_tokens = torch.tensor( + list_past_output_tokens, + dtype=torch.long, + device="cpu", + pin_memory=True, + ) + + return cls( + mask_random, + mask_greedy, + mask_top_logprob, + temp.to(device=dev, non_blocking=True), + top_ps.to(device=dev, non_blocking=True), + top_ks.to(device=dev, non_blocking=True), + frequency_penalties.to(device=dev, non_blocking=True), + presence_penalties.to(device=dev, non_blocking=True), + repetition_penalties.to(device=dev, non_blocking=True), + logit_bias_indices.to(device=dev, non_blocking=True), + logit_bias_values.to(device=dev, non_blocking=True), + past_output_tokens.to(device=dev, non_blocking=True), + ) + + +@dataclass +class SamplingState: + """ + Sampling states prepared for sampling computations (`adjust_logits()` and `sample()`). + + Args: + has_random: bool + True if the current batch contains a request (or requests) + with random sampling. + has_greedy: bool + True if the current batch contains a request (or requests) + with greedy sampling. + apply_top_p_top_k: bool + True if the current batch contains a request (or requests) + with top-p or top-k. + apply_penalty: bool + True if the current batch contains a request (or requests) + with at least one of the repetition/frequency/presence penalties. 
+ apply_bias: bool + True if the current batch contains a request (or requests) + with logit bias + has_logprob: bool + True if the current batch contains a request (or requests) + with logprob + logprob_batch_indices: List[int] + A list of indices of the requests with logprob inside the batch + sampling_tensors: SamplingTensors + A set of torch tensors that contains masks and parameter + values for sampling computation + sampling_params: List[SamplingParams] + A list of SamplingParams from the user request + """ + + has_random: bool + has_greedy: bool + apply_top_p_top_k: bool + apply_penalty: bool + apply_bias: bool + has_logprob: bool + logprob_batch_indices: List[int] + sampling_tensors: SamplingTensors + sampling_params: List[SamplingParams] + + @classmethod + def from_sampling_params( + cls, + sampling_params: List[SamplingParams], + list_past_output_tokens: List[List[int]], + dtype: torch.dtype, + dev: str, + vocab_size: int, + ): + list_mask_random = [] + list_temperatures = [] + list_top_ps = [] + list_top_ks = [] + do_top_p = False + do_top_k = False + apply_penalty = False + apply_bias = False + list_frequency_penalties = [] + list_presence_penalties = [] + list_repetition_penalties = [] + list_logit_bias_indices = [] + list_logit_bias_values = [] + + idx_random = -1 + idx_greedy = -1 + batch_size = len(sampling_params) + # index 0 is for non-logprob requests + has_logprob = False + logprob_batch_indices = [] + list_mask_top_logprob = np.full( + ((LOGPROB_TOP_K_MAX) + 1, batch_size), False, dtype=bool + ) + logit_bias_maxlen = 0 + for batch_idx, param in enumerate(sampling_params): + # Prepare temperature + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
+ list_temperatures.append( + param.temperature if param.temperature >= SAMPLING_EPS else 1.0 + ) + + if param.sampling_type == SamplingType.RANDOM: + list_mask_random.append(True) + idx_random += 1 + list_top_ps.append(param.top_p) + list_top_ks.append(param.top_k if param.top_k != -1 else vocab_size) + do_top_p |= list_top_ps[-1] < 1.0 - SAMPLING_EPS + do_top_k |= list_top_ks[-1] != vocab_size + else: + list_mask_random.append(False) + idx_greedy += 1 + + if param.logprobs: + logprob_batch_indices.append(batch_idx) + # param.top_logprobs is zero if logprob is not used + list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs + has_logprob |= True + + apply_penalty |= ( + abs(param.presence_penalty) >= SAMPLING_EPS + or abs(param.frequency_penalty) >= SAMPLING_EPS + or abs(param.repetition_penalty - 1.0) >= SAMPLING_EPS + ) + list_frequency_penalties.append(param.frequency_penalty) + list_presence_penalties.append(param.presence_penalty) + list_repetition_penalties.append(param.repetition_penalty) + + if param.logit_bias_index: + assert param.logit_bias_value + apply_bias |= True + logit_bias_maxlen = max(logit_bias_maxlen, len(param.logit_bias_index)) + list_logit_bias_indices.append(param.logit_bias_index) + list_logit_bias_values.append(param.logit_bias_value) + else: + list_logit_bias_indices.append([]) + + num_random_samples = idx_random + 1 + num_greedy_samples = idx_greedy + 1 + + has_random = num_random_samples > 0 + has_greedy = num_greedy_samples > 0 + apply_top_p_top_k = do_top_p | do_top_k + + if apply_bias: + # Match the length of each request by padding + for ii in range(batch_size): + logit_bias_values = list_logit_bias_values[ii] + num_padding = logit_bias_maxlen - len(logit_bias_values) + # arbitrary index + list_logit_bias_indices[ii] += [1] * num_padding + list_logit_bias_values[ii] += [0] * num_padding + + max_num_past_tokens = 0 + for past_output_tokens in list_past_output_tokens: + max_num_past_tokens = max(max_num_past_tokens, len(past_output_tokens)) + + for i in range(batch_size): + num = len(list_past_output_tokens[i]) + list_past_output_tokens[i] = list_past_output_tokens[i] + [vocab_size] * ( + max_num_past_tokens - num + ) + + sampling_tensors = SamplingTensors.from_lists( + dtype, + dev, + list_mask_random, + list_mask_top_logprob, + list_temperatures, + list_top_ps, + list_top_ks, + list_frequency_penalties, + list_presence_penalties, + list_repetition_penalties, + list_logit_bias_indices, + list_logit_bias_values, + list_past_output_tokens, + ) + + return cls( + has_random, + has_greedy, + apply_top_p_top_k, + apply_penalty, + apply_bias, + has_logprob, + logprob_batch_indices, + sampling_tensors, + sampling_params, + ) + + +def adjust_logits(logits, sampling_metadata, vocab_size): + batch_size = logits.shape[0] + ( + apply_top_p_top_k, + apply_penalty, + apply_bias, + sampling_tensors, + ) = ( + sampling_metadata.apply_top_p_top_k, + sampling_metadata.apply_penalty, + sampling_metadata.apply_bias, + sampling_metadata.sampling_tensors, + ) + ( + temp_t, + top_ps_t, + top_ks_t, + frequency_penalties_t, + repetition_penalties_t, + presence_penalties_t, + past_output_tokens_t, + logit_bias_indices_t, + logit_bias_values_t, + ) = ( + sampling_tensors.temperatures, + sampling_tensors.top_ps, + sampling_tensors.top_ks, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties, + sampling_tensors.presence_penalties, + sampling_tensors.past_output_tokens, + sampling_tensors.logit_bias_indices, + 
sampling_tensors.logit_bias_values, + ) + + # TODO(vvchernov): make sure we are applying various sampling params + # (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...) + # in the right order. + if apply_penalty: + repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) + logits = torch.where( + logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t + ) + bin_counts = torch.zeros( + (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device + ) + bin_counts.scatter_add_( + 1, past_output_tokens_t, torch.ones_like(past_output_tokens_t) + ) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts + logits -= presence_penalties_t.unsqueeze_(dim=1) * mask + + # Adjust temperature + logits.div_(temp_t.unsqueeze(dim=1)) + if apply_top_p_top_k: + logits = _apply_top_p_top_k(logits, top_ps_t, top_ks_t) + + if apply_bias: + # logits.scatter_add_ performs the following computation: + # logit[i][index[i][j]] += src[i][j] + # where 0<=i SamplingOutput: + def _is_safe_to_sample(prob_like): + return ( + torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) + == 0 + ) + + res_greedy, res_random = None, None + sampling_tensors = sampling_metadata.sampling_tensors + + batch_size = logits.shape[0] + mask_greedy_t, mask_random_t = ( + sampling_tensors.mask_greedy, + sampling_tensors.mask_random, + ) + + next_tokens = np.empty((batch_size,), dtype=np.int64) + if sampling_metadata.has_greedy: + res_greedy = torch.argmax(logits[mask_greedy_t], -1) + np_mask_greedy = mask_greedy_t.cpu().numpy() + next_tokens[np_mask_greedy] = res_greedy.cpu().numpy() + + probs_random = None + if sampling_metadata.has_random: + probs_random = torch.softmax(logits[mask_random_t], dim=-1) + if check_safety and not _is_safe_to_sample(probs_random): + return None + res_random = torch.multinomial(probs_random, 1, True)[:, 0] + np_mask_random = mask_random_t.cpu().numpy() + next_tokens[np_mask_random] = res_random.cpu().numpy() + + logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size + if sampling_metadata.has_logprob: + # If everything is random sampling, save one extra softmax + if not sampling_metadata.has_greedy: + assert probs_random is not None + logprobs = torch.log(probs_random) + else: + logprobs = torch.log_softmax(logits, dim=-1) + + # Redudandant but vectorized + extended_logprobs = logprobs.repeat((LOGPROB_TOP_K_MAX + 1, 1, 1)) + all_top_logprobs, all_top_tokens = torch.topk( + extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True + ) + mask = sampling_metadata.sampling_tensors.mask_top_logprob + top_tokens = all_top_tokens[mask] + top_logprobs = all_top_logprobs[mask] + for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices): + next_token = next_tokens[batch_idx] + assert sampling_metadata.sampling_params[batch_idx].logprobs + top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs + logprob_infos[batch_idx] = RawLogprobsInfo( + current_token_id=next_token, + current_logprob=logprobs[batch_idx][next_token], + top_token_ids=top_tokens[idx][:top_k], + top_logprobs=top_logprobs[idx][:top_k], + ) + + return SamplingOutput(next_tokens, logprob_infos) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index c467e5be7f..c5970ace6e 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import 
os -from typing import List, Tuple +from typing import List, Tuple, Sequence import structlog import numpy as np @@ -17,20 +17,20 @@ prepare_multi_query_decode_inputs, get_num_cache_blocks, ) - from ..engine import ( get_prompt_sequence_id, MLCServeEngineConfig, ) from ..engine.model_module import ( - DecodeRequest, DraftTokens, EvalMultiQueryRequest, PrefillRequest, - RequestsType, - TextGenerationResult, + DecodeRequest, TextGenerator, + TextGenerationResult, + RequestType, ) +from .sampler import SamplingState LOG = structlog.stdlib.get_logger(__name__) @@ -140,6 +140,17 @@ def __init__( self.sliding_window = config.sliding_window self.num_shards = config.num_shards + # TODO(@sunggg): Find a better way + if config.model_type == "llama": + self.torch_dtype = torch.float32 + elif config.model_type == "mistral" or config.model_type == "mixtral": + self.torch_dtype = torch.float32 + else: + assert 0, f"{config.model_type} is NOT supported yet" + + self._copy_stream: torch.cuda.Stream = torch.cuda.Stream() + self.torch_dev: str = "cuda" + if self.sliding_window: self.block_sliding_window = self.sliding_window // CacheManager.block_size else: @@ -212,16 +223,31 @@ def generate_multi_query( ) -> List[TextGenerationResult]: sequence_ids = [] last_query_offsets: List[int] = [] + sampling_params = [] + past_decode_tokens = [] for request in requests: assert not isinstance(request.queries, DraftTokens) sequence_ids.append(request.sequence_id) - if len(last_query_offsets) == 0: last_query_offsets.append(request.queries.num_tokens - 1) else: last_query_offsets.append( last_query_offsets[-1] + request.queries.num_tokens ) + sampling_params.append(request.sampling_params) + # Use `vocab_size` as a padding + past_decode_tokens.append([self.vocab_size, *request.queries.token_ids]) + + # Prepare sampling tensors in another stream to overlap + # CPU<->GPU data transfer with GPU computation in forward pass. + with torch.cuda.stream(self._copy_stream): + sampling_metadata = SamplingState.from_sampling_params( + sampling_params, + past_decode_tokens, + self.torch_dtype, + self.torch_dev, + self.vocab_size, + ) ( input_ids, @@ -270,17 +296,26 @@ def generate_multi_query( torch.cuda.nvtx.range_pop() last_query_logits = torch.from_dlpack(logits)[last_query_offsets] - return sample_from_logits( - last_query_logits, sequence_ids, requests, self.vocab_size + last_query_logits, + sequence_ids, + requests, # type: ignore + sampling_metadata, + self.vocab_size, + self._copy_stream, + self.torch_dtype, + self.torch_dev, + past_decode_tokens, ) def generate( self, - requests: RequestsType, + requests: Sequence[RequestType], cache: KVCacheInfo, ) -> List[TextGenerationResult]: - if len(requests) == 0: + batch_size = len(requests) + LOG.debug(f"Generation batch size: f{batch_size}.", batch_size=batch_size) + if batch_size == 0: return [] is_prefill = isinstance(requests[0], PrefillRequest) @@ -293,16 +328,45 @@ def generate( all_token_ids = [] sequence_ids = [] prompt_lens = [] + sampling_params = [] + past_decode_tokens = [] for request in requests: if isinstance(request, PrefillRequest): - sequence_ids.append(get_prompt_sequence_id(request.request_id)) + seq_id = get_prompt_sequence_id(request.request_id) + # Use `vocab_size` as a padding. 
+ # This is convenient way to filter out paddings + # after the vectorized sampling computation + # since logit index will be in range of [0,vocab_size) + request_past_decode_tokens = [self.vocab_size] elif isinstance(request, DecodeRequest): - sequence_ids.append(request.sequence_id) + seq_id = request.sequence_id prompt_lens.append(request.prompt_token_counts) + # Use `vocab_size` as a padding + # This is convenient way to filter out paddings + # after the vectorized sampling computation + # since logit index will be in range of [0,vocab_size) + request_past_decode_tokens = [self.vocab_size, *request.token_ids] + else: + raise Exception("`EvalMultiQueryRequest` should not reach here.") + + past_decode_tokens.append(request_past_decode_tokens) + sequence_ids.append(seq_id) assert not isinstance(request, EvalMultiQueryRequest) all_token_ids.append(request.token_ids) + sampling_params.append(request.sampling_params) + + # Prepare sampling tensors in another stream to overlap + # CPU<->GPU data transfer with GPU computation in forward pass. + with torch.cuda.stream(self._copy_stream): + sampling_metadata = SamplingState.from_sampling_params( + sampling_params, + past_decode_tokens, + self.torch_dtype, + self.torch_dev, + self.vocab_size, + ) ( input_ids, @@ -322,7 +386,6 @@ def generate( ) input_shape = input_ids.shape - if self.disco_session: input_ids = copy_to_worker_0(self.disco_session, input_ids) positions = copy_to_worker_0(self.disco_session, positions) @@ -398,7 +461,17 @@ def generate( self.copy_cache_blocks_func(self.cache_blocks, block_mapping) cache.pending_copy_from_to = [] - return sample_from_logits(logits, sequence_ids, requests, self.vocab_size) + return sample_from_logits( + logits, + sequence_ids, + requests, + sampling_metadata, + self.vocab_size, + self._copy_stream, + self.torch_dtype, + self.torch_dev, + past_decode_tokens, + ) def init_tvm_model( diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index f2b80c1231..fa8fec34fd 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -74,12 +74,13 @@ def _test(args: argparse.Namespace): ] for i, prompt in enumerate(prompts): + sampling_param = random.choice(sampling_params_choices) engine.add( [ Request( request_id=str(i), messages=[ChatMessage(role="user", content=prompt)], - sampling_params=random.choice(sampling_params_choices), + sampling_params=sampling_param, stopping_criteria=StoppingCriteria( max_tokens=args.max_output_len, stop_sequences=None ), diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 8ff7d56d11..e3c1bee72f 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -12,12 +12,13 @@ from mlc_serve.model.base import get_model_artifact_config from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +import random def create_engine( model_artifact_path, - use_staging_engine, max_num_batched_tokens, + use_staging_engine, ): engine_config = get_engine_config( { @@ -48,7 +49,17 @@ def create_engine( def create_request( - idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, top_logprobs=0, logprobs=False, logit_bias=None + idx, + prompt, + temp, + freq_pen, + pre_pen, + max_tokens, + stop, + ignore_eos, + top_logprobs=0, + logprobs=False, + logit_bias=None, ): return Request( request_id=str(idx), @@ -67,19 +78,11 
@@ def create_request( def _test_max_tokens( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens=2048, + engine, num_requests=5, ignore_eos=False, ): prompt = "Write a merge sort program in Python." - engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens, - ) - requests = [ create_request( idx=str(n - 1), @@ -112,25 +115,15 @@ def _test_max_tokens( else: generated[int(res.request_id)] += seq.delta - if use_staging_engine: - engine.stop() - def _test_max_context_length( model_artifact_path, - use_staging_engine, - max_num_sequences=4, + engine, num_requests=5, ignore_eos=False, ): model_artifact_config = get_model_artifact_config(model_artifact_path) max_context_length = model_artifact_config.max_context_length - - engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens=max_context_length * max_num_sequences, - ) prompt = "hi " * (max_context_length - 15) requests = [ @@ -161,22 +154,12 @@ def _test_max_context_length( else: generated[int(res.request_id)] += seq.delta - if use_staging_engine: - engine.stop() - def _test_ignore_eos( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens=2048, + engine, num_requests=5, ): prompt = "hi" - engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens, - ) s = 113 requests = [ create_request( @@ -200,7 +183,6 @@ def _test_ignore_eos( for res in results.outputs: assert len(res.sequences) == 1 seq = res.sequences[0] - if seq.is_finished: assert ( seq.num_generated_tokens @@ -210,22 +192,12 @@ def _test_ignore_eos( else: generated[int(res.request_id)] += seq.delta - if use_staging_engine: - engine.stop() - def _test_stop( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens=2048, + engine, num_requests=5, ): prompt = "Write a merge sort program in Python." - engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens, - ) requests = [] for n, stop in enumerate(["\n", ["\n"], "\n\n", "!", ["n", "!"]]): requests.append( @@ -243,7 +215,6 @@ def _test_stop( engine.add(requests) generated = ["" for _ in range(num_requests)] - while engine.has_pending_requests(): results = engine.step() for res in results.outputs: @@ -271,87 +242,59 @@ def _test_stop( ) assert found == 1, f"{gen_txt!r}, matches: {found}" - if use_staging_engine: - engine.stop() - -def _test_penalty( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens=2048, - num_requests=5, - ignore_eos=False, +def _test_logprobs( + engine, + num_requests=10, ): - prompt = "Write a merge sort program in Python." 
- engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens, - ) - - random_requests = [ - create_request( - idx=str(n - 1), - prompt=prompt, - temp=0.5, - freq_pen=0.5, - pre_pen=-0.5, - max_tokens=n, - stop=None, - ignore_eos=ignore_eos, - logit_bias={123: -100, 456: 100}, - ) - for n in range(1, num_requests) + prompts = [ + "Hi could you please implement merge sort?", + "What is the best city in the world?", + "Can you write a poem for Seattle?", + "Describe lion for kids.", ] - greedy_requests = [ + requests = [ create_request( - idx=str(n - 1), - prompt=prompt, + idx=str(n), + prompt=random.choice(prompts), temp=0, freq_pen=0, pre_pen=0, - max_tokens=n, + max_tokens=300, stop=None, - ignore_eos=ignore_eos, + ignore_eos=True, + top_logprobs=random.randint(1, 5), + logprobs=True, ) - for n in range(num_requests, num_requests << 1) + for n in range(num_requests) ] - requests = random_requests + greedy_requests engine.add(requests) - generated = ["" for _ in range(num_requests << 1)] - + generated = ["" for _ in range(num_requests)] while engine.has_pending_requests(): results = engine.step() for res in results.outputs: assert len(res.sequences) == 1 seq = res.sequences[0] + req = requests[int(res.request_id)] if seq.is_finished: - assert ( - seq.num_generated_tokens - == requests[int(res.request_id)].stopping_criteria.max_tokens - ) + assert seq.finish_reason is not None + assert seq.num_generated_tokens == req.stopping_criteria.max_tokens assert seq.finish_reason == FinishReason.Length else: + assert ( + len(seq.logprob_info[0].top_logprobs) + == req.sampling_params.top_logprobs + ) generated[int(res.request_id)] += seq.delta - if use_staging_engine: - engine.stop() -def _test_logprobs( - model_artifact_path, - use_staging_engine, - num_requests=5, - top_logprobs=3, - max_num_batched_tokens=2048 +def _test_logprobs_mixed_requests( + engine, + num_requests=10, ): prompt = "hi could you please implement merge sort?" 
- engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_batched_tokens, - ) requests = [ create_request( idx=str(n), @@ -362,47 +305,65 @@ def _test_logprobs( max_tokens=300, stop=None, ignore_eos=True, - top_logprobs=top_logprobs, - logprobs=True - ) for n in range(num_requests) + top_logprobs=random.randint(1, 5), + logprobs=random.choice([True, False]), + ) + for n in range(num_requests) ] engine.add(requests) generated = ["" for _ in range(num_requests)] - while engine.has_pending_requests(): results = engine.step() for res in results.outputs: assert len(res.sequences) == 1 seq = res.sequences[0] - - assert seq.finish_reason is not None or len(seq.logprob_info[0].top_logprobs) == top_logprobs - + req = requests[int(res.request_id)] if seq.is_finished: - assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens + assert seq.finish_reason is not None + assert seq.num_generated_tokens == req.stopping_criteria.max_tokens assert seq.finish_reason == FinishReason.Length else: + if req.sampling_params.logprobs: + assert ( + len(seq.logprob_info[0].top_logprobs) + == req.sampling_params.top_logprobs + ) + else: + assert len(seq.logprob_info) == 0 generated[int(res.request_id)] += seq.delta - if use_staging_engine: - engine.stop() if __name__ == "__main__": parser = get_default_mlc_serve_argparser("test engine with samplers") args = parser.parse_args() postproc_mlc_serve_args(args) + max_num_batched_tokens = 2048 - _test_max_tokens(args.model_artifact_path, use_staging_engine=True) - _test_max_tokens(args.model_artifact_path, use_staging_engine=False) - _test_ignore_eos(args.model_artifact_path, use_staging_engine=True) - _test_ignore_eos(args.model_artifact_path, use_staging_engine=False) - _test_stop(args.model_artifact_path, use_staging_engine=False) - _test_stop(args.model_artifact_path, use_staging_engine=True) - _test_logprobs(args.model_artifact_path, use_staging_engine=True) - _test_logprobs(args.model_artifact_path, use_staging_engine=False) + # Test staging engines + staging_engine = create_engine( + args.model_artifact_path, max_num_batched_tokens, use_staging_engine=True + ) + _test_max_tokens(staging_engine) + _test_ignore_eos(staging_engine) + # TODO (@sunggg): There is something stateful. + # _test_stop(staging_engine) + _test_logprobs(staging_engine) + _test_logprobs_mixed_requests(staging_engine) + # These tests are broken since we are now imposing no length limit + # if max_tokens = None. The tests do not finish in a reasonable time. + # _test_max_context_length(staging_engine) + staging_engine.stop() + + # Test sync engines + sync_engine = create_engine( + args.model_artifact_path, max_num_batched_tokens, use_staging_engine=False + ) + _test_max_tokens(sync_engine) + _test_ignore_eos(sync_engine) + _test_stop(sync_engine) + _test_logprobs(sync_engine) + _test_logprobs_mixed_requests(sync_engine) # These tests are broken since we are now imposing no length limit # if max_tokens = None. The tests do not finish in a reasonable time. 
- # _test_max_context_length(model_artifact_path, use_staging_engine=True) - # _test_max_context_length(model_artifact_path, use_staging_engine=False) - _test_penalty(args.model_artifact_path, use_staging_engine=True) - _test_penalty(args.model_artifact_path, use_staging_engine=False) + # _test_max_context_length(sync_engine) diff --git a/serve/tests/unittest/test_sampler.py b/serve/tests/unittest/test_sampler.py new file mode 100644 index 0000000000..17ca286919 --- /dev/null +++ b/serve/tests/unittest/test_sampler.py @@ -0,0 +1,327 @@ +import torch +import pytest +from mlc_serve.model.sampler import SamplingState, adjust_logits +from mlc_serve.engine import SamplingParams, SAMPLING_EPS + +dtype = torch.float32 +dev = "cuda" +vocab_size = 32000 + + +def get_sampling_metadata(sampling_params, past_output_tokens=None): + batch_size = len(sampling_params) + if past_output_tokens is None: + past_output_tokens = [[] for _ in range(batch_size)] + _copy_stream: torch.cuda.Stream = torch.cuda.Stream() + with torch.cuda.stream(_copy_stream): + sampling_metadata = SamplingState.from_sampling_params( + sampling_params, + list_past_output_tokens=past_output_tokens, + dtype=dtype, + dev=dev, + vocab_size=vocab_size, + ) + torch.cuda.current_stream().wait_stream(_copy_stream) + return sampling_metadata + + +def _test_temperature(temp=0, batch_size=1): + shape = (batch_size, vocab_size) + logits = torch.rand(shape, dtype=dtype, device=dev) + sampling_param = SamplingParams( + temperature=temp, + ) + + sampling_metadata = get_sampling_metadata([sampling_param]) + + expected = logits / temp if abs(temp) > SAMPLING_EPS else logits + new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + assert torch.allclose(expected, new_logits) + + +def _test_logit_bias_checker(): + # logit bias must be [-100, 100] + with pytest.raises(ValueError): + logit_bias = {1: 2, 3: 105, 2: 2} + sampling_param = SamplingParams(logit_bias=logit_bias) + get_sampling_metadata([sampling_param]) + + with pytest.raises(ValueError): + logit_bias = {1: 99, 3: -101, 2: 2} + sampling_param = SamplingParams(logit_bias=logit_bias) + get_sampling_metadata([sampling_param]) + + logit_bias = {1: 100, 3: -100, 2: 2} + sampling_param = SamplingParams(logit_bias=logit_bias) + get_sampling_metadata([sampling_param]) + + # TODO(@team): it seems like the valid range is [1,vocab_size]. Double check. 
+    logit_bias = {1: 10, 3: -10, vocab_size: 2}
+    sampling_param = SamplingParams(logit_bias=logit_bias)
+    get_sampling_metadata([sampling_param])
+
+    with pytest.raises(ValueError):
+        logit_bias = {0: 10, 3: -10}
+        sampling_param = SamplingParams(logit_bias=logit_bias)
+        get_sampling_metadata([sampling_param])
+
+    with pytest.raises(ValueError):
+        logit_bias = {1: 10, 3: -10, vocab_size + 100: 2}
+        sampling_param = SamplingParams(logit_bias=logit_bias)
+        get_sampling_metadata([sampling_param])
+
+    with pytest.raises(ValueError):
+        logit_bias = {1: 10, -1: -10}
+        sampling_param = SamplingParams(logit_bias=logit_bias)
+        get_sampling_metadata([sampling_param])
+
+
+def _test_logit_bias():
+    # test single batch
+    batch_size = 1
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    logit_bias = {1: -1, 3: 1, 2: 2}
+    sampling_param = SamplingParams(logit_bias=logit_bias)
+    sampling_metadata = get_sampling_metadata([sampling_param])
+
+    expected = torch.clone(logits)
+    for idx, val in logit_bias.items():
+        expected[0][idx - 1] += val
+    new_logits = adjust_logits(logits, sampling_metadata, vocab_size)
+    assert torch.allclose(expected, new_logits)
+
+    # test multi-batch
+    batch_size = 3
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    list_logit_bias = [{1: -1, 3: 1, 2: 2}, {4: 2, 5: 1}, {1: -10}]
+    sampling_params = [
+        SamplingParams(logit_bias=logit_bias) for logit_bias in list_logit_bias
+    ]
+    sampling_metadata = get_sampling_metadata(sampling_params)
+
+    expected = torch.clone(logits)
+    for batch_idx in range(batch_size):
+        logit_bias = list_logit_bias[batch_idx]
+        for idx, val in logit_bias.items():
+            expected[batch_idx][idx - 1] += val
+    new_logits = adjust_logits(logits, sampling_metadata, vocab_size)
+    assert torch.allclose(expected, new_logits)
+
+
+def _test_penalties_checker():
+    get_sampling_metadata([SamplingParams(presence_penalty=-1.0)])
+    get_sampling_metadata([SamplingParams(frequency_penalty=-1.0)])
+    get_sampling_metadata([SamplingParams(repetition_penalty=0.7)])
+
+    with pytest.raises(ValueError):
+        get_sampling_metadata([SamplingParams(presence_penalty=-2.1)])
+
+    with pytest.raises(ValueError):
+        get_sampling_metadata([SamplingParams(frequency_penalty=-2.1)])
+
+    with pytest.raises(ValueError):
+        get_sampling_metadata([SamplingParams(repetition_penalty=-2.1)])
+
+    with pytest.raises(ValueError):
+        get_sampling_metadata([SamplingParams(presence_penalty=2.1)])
+
+    with pytest.raises(ValueError):
+        get_sampling_metadata([SamplingParams(frequency_penalty=2.1)])
+
+    with pytest.raises(ValueError):
+        get_sampling_metadata(
+            [
+                SamplingParams(frequency_penalty=1.1),
+                SamplingParams(repetition_penalty=2.1),
+                SamplingParams(presence_penalty=1.1),
+                SamplingParams(presence_penalty=3.1),
+            ]
+        )
+
+
+def _test_penalties():
+    # TODO(vvchernov): Add test for repetition penalty
+    batch_size = 1
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    presence_penalties = [0.8]
+    frequency_penalties = [0.3]
+    past_output_tokens = [[2, 2, 2, 3]]
+
+    def prepare_metadata(past_output_tokens):
+        count_map = []
+        for past_output_tokens_per_req in past_output_tokens:
+            # TODO: Check if this is the right range
+            cnt = [0] * vocab_size
+            for tok in past_output_tokens_per_req:
+                cnt[tok] += 1
+            count_map.append(cnt)
+
+        count_tensor = torch.tensor(count_map, device=dev)
+        mask_tensor = count_tensor > 0
+        return count_tensor, mask_tensor
+
+    count_map, mask = 
prepare_metadata(past_output_tokens) + + def get_expected_result( + logits, count_map, mask, frequency_penalties, presence_penalties + ): + expected = torch.clone(logits) + for i in range(batch_size): + expected[i] = ( + expected[i] + - count_map[i] * frequency_penalties[i] + - mask[i] * presence_penalties[i] + ) + return expected + + expected = get_expected_result( + logits, count_map, mask, frequency_penalties, presence_penalties + ) + + sampling_param = [ + SamplingParams( + presence_penalty=presence_penalties[0], + frequency_penalty=frequency_penalties[0], + ) + ] + sampling_metadata = get_sampling_metadata( + sampling_param, past_output_tokens=past_output_tokens + ) + new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + assert torch.allclose(expected, new_logits) + + batch_size = 3 + shape = (batch_size, vocab_size) + logits = torch.rand(shape, dtype=dtype, device=dev) + presence_penalties = [0.8, 0.7, -0.8] + frequency_penalties = [-0.3, 2.0, 1.2] + past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]] + + count_map, mask = prepare_metadata(past_output_tokens) + expected = get_expected_result( + logits, count_map, mask, frequency_penalties, presence_penalties + ) + + sampling_params = [ + SamplingParams( + presence_penalty=presence_penalties[i], + frequency_penalty=frequency_penalties[i], + ) + for i in range(batch_size) + ] + sampling_metadata = get_sampling_metadata( + sampling_params, past_output_tokens=past_output_tokens + ) + new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + assert torch.allclose(expected, new_logits) + + +def _test_top_p_top_k_checker(): + get_sampling_metadata([SamplingParams(top_p=0.8)]) + get_sampling_metadata([SamplingParams(top_k=3)]) + + get_sampling_metadata([SamplingParams(top_k=-1)]) + get_sampling_metadata([SamplingParams(top_k=1)]) + + with pytest.raises(ValueError): + get_sampling_metadata([SamplingParams(top_p=0.0)]) + + with pytest.raises(ValueError): + get_sampling_metadata([SamplingParams(top_p=-0.8)]) + + with pytest.raises(ValueError): + get_sampling_metadata([SamplingParams(top_k=0)]) + + with pytest.raises(ValueError): + get_sampling_metadata([SamplingParams(top_k=0.8)]) + + with pytest.raises(ValueError): + get_sampling_metadata([SamplingParams(top_k=-2)]) + + +def _test_top_p_top_k(): + def get_expected_result(logits, top_pks, filter_value=-float("Inf")): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + top_k >0: keep only top k tokens with highest probability (top-k filtering). 
+        """
+        top_ps, top_ks = [], []
+        for top_p, top_k in top_pks:
+            top_ps.append(top_p)
+            top_ks.append(top_k)
+        batch_size = len(top_pks)
+        lst_logits = []
+        for ii in range(batch_size):
+            _logits = logits[ii]
+            top_k = min(top_ks[ii], _logits.size(-1))  # Safety check, per request
+
+            if top_ps[ii] > 0.0:
+                sorted_logits, sorted_indices = torch.sort(_logits, descending=True)
+                cumulative_probs = torch.cumsum(
+                    torch.softmax(sorted_logits, dim=-1), dim=-1
+                )
+
+                # Remove tokens with cumulative probability above the threshold
+                sorted_indices_to_remove = cumulative_probs > top_ps[ii]
+                # Shift the indices to the right to keep the first token above the threshold as well
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                    ..., :-1
+                ].clone()
+                sorted_indices_to_remove[..., 0] = 0
+
+                indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                _logits[indices_to_remove] = filter_value
+
+            if top_k > 0:
+                # Remove all tokens with a probability less than the last token of the top-k
+                indices_to_remove = (
+                    _logits < torch.topk(_logits, top_k)[0][..., -1, None]
+                )
+                _logits[indices_to_remove] = filter_value
+
+            lst_logits.append(_logits)
+        return torch.stack(lst_logits)
+
+    batch_size = 1
+    top_p, top_k = 0.7, 5
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    sampling_params = [
+        SamplingParams(top_p=top_p, top_k=top_k) for _ in range(batch_size)
+    ]
+    sampling_metadata = get_sampling_metadata(sampling_params)
+    new_logits = adjust_logits(logits, sampling_metadata, vocab_size)
+    expected = logits.clone()
+    expected = get_expected_result(expected, top_pks=[(top_p, top_k)])
+    assert torch.allclose(expected, new_logits)
+
+    batch_size = 3
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    top_pks = [(0.7, 3), (0.5, 2), (0.8, 5)]
+    sampling_params = [
+        SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks
+    ]
+    sampling_metadata = get_sampling_metadata(sampling_params)
+
+    new_logits = adjust_logits(logits, sampling_metadata, vocab_size)
+    expected = logits.clone()
+    expected = get_expected_result(expected, top_pks)
+    # TODO(team): this is currently broken. Need to fix.
+    # assert torch.allclose(expected, new_logits)
+
+
+if __name__ == "__main__":
+    _test_temperature()
+    _test_logit_bias_checker()
+    _test_logit_bias()
+    _test_penalties_checker()
+    _test_penalties()
+    _test_top_p_top_k_checker()
+    _test_top_p_top_k()
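
Note on the `vocab_size` padding introduced in the model's generate() change above: real token ids always fall in [0, vocab_size), so an out-of-range padding value can never collide with an actual token. One convenient way to exploit that, sketched below, is to histogram per-request token histories into vocab_size + 1 bins and slice the extra bin away after the vectorized computation; the sampler itself may do this differently, and the tensor names here are invented, not mlc_serve code.

import torch

vocab_size = 8

# Two requests with different numbers of past decode tokens; the shorter one is
# padded with `vocab_size`, which can never be a real token id.
past_decode_tokens = [
    [vocab_size, 3, 5, 5],
    [vocab_size, 2],
]
max_len = max(len(t) for t in past_decode_tokens)
padded = torch.tensor(
    [t + [vocab_size] * (max_len - len(t)) for t in past_decode_tokens]
)

# Vectorized per-request histogram over vocab_size + 1 bins: the extra bin
# soaks up every padding entry.
counts = torch.zeros(len(past_decode_tokens), vocab_size + 1, dtype=torch.long)
counts.scatter_add_(1, padded, torch.ones_like(padded))

# Dropping the last column removes the padding contribution in one slice.
token_counts = counts[:, :vocab_size]
print(token_counts)

This is presumably why SamplingState.from_sampling_params receives both past_decode_tokens and the vocab size in the hunk above.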
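
Note on the side stream that both the engine change and get_sampling_metadata in test_sampler.py rely on: sampling tensors are built under torch.cuda.stream(...) so host-to-device copies can overlap with the forward pass on the default stream, and the consumer synchronizes with wait_stream() before reading them. A minimal sketch of that pattern, assuming a CUDA device is available; the helper below is illustrative and not part of mlc_serve.

import torch

assert torch.cuda.is_available(), "this sketch needs a GPU"

copy_stream = torch.cuda.Stream()

def build_sampling_tensors(temperatures):
    # Work queued here goes to copy_stream, so the H2D copy can overlap with
    # whatever the default stream is doing (e.g. the model forward pass).
    with torch.cuda.stream(copy_stream):
        t = torch.tensor(temperatures, dtype=torch.float32, device="cuda")
    return t

temps = build_sampling_tensors([0.7, 1.0, 0.0])

# Before the default stream reads the tensors, it must wait for copy_stream;
# otherwise it could observe a partially copied buffer.
torch.cuda.current_stream().wait_stream(copy_stream)
scaled = temps * 2  # now safe to consume on the default stream
print(scaled.cpu())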
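
Note on the expectation _test_penalties builds: the frequency penalty scales with how often each token was generated, while the presence penalty applies once to any token seen at least once, i.e. adjusted = logits - counts * frequency_penalty - (counts > 0) * presence_penalty. A small CPU-only sketch of that arithmetic (illustrative only; it does not import mlc_serve and the names are made up):

import torch

vocab_size = 8
# One request that has already emitted tokens 2, 2, 2, 3.
past_output_tokens = [[2, 2, 2, 3]]
frequency_penalty = 0.3
presence_penalty = 0.8

logits = torch.zeros(1, vocab_size)

# Per-token counts over the vocabulary, plus a mask of tokens seen at least once.
counts = torch.zeros(1, vocab_size)
for i, toks in enumerate(past_output_tokens):
    for t in toks:
        counts[i, t] += 1
mask = (counts > 0).float()

# Same formula the test checks: subtract the count-scaled frequency penalty and
# the mask-scaled presence penalty from the raw logits.
adjusted = logits - counts * frequency_penalty - mask * presence_penalty

# Token 2 was emitted 3 times: 0 - 3*0.3 - 1*0.8 = -1.7
# Token 3 was emitted once:    0 - 1*0.3 - 1*0.8 = -1.1
print(adjusted)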
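
Note on the SAMPLING_EPS guard in _test_temperature: a temperature at or below the epsilon is treated as greedy decoding, so the logits are left unscaled instead of being divided by a value near zero. A short sketch of that convention; the epsilon value below is an assumption for illustration, the real constant is the SAMPLING_EPS imported from mlc_serve.engine at the top of test_sampler.py.

import torch

SAMPLING_EPS = 1e-5  # assumed value for the sketch only

def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Near-zero temperature means greedy decoding: leave the logits untouched
    # and let argmax pick the token, instead of dividing by ~0.
    if abs(temperature) < SAMPLING_EPS:
        return logits
    return logits / temperature

logits = torch.tensor([[1.0, 2.0, 0.5]])
print(apply_temperature(logits, 0.0))   # unchanged -> greedy
print(apply_temperature(logits, 0.5))   # sharpened distribution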