Skip to content

Commit

Permalink
use RequestType
Browse files: browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 2, 2024
1 parent ec7b61d commit e58b7d3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
7 changes: 3 additions & 4 deletions serve/mlc_serve/engine/engine_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ConversationTemplate,
KVCacheManager,
ModelModule,
RequestType,
TextGenerator,
Tokenizer as TokenizerP,
)
Expand Down Expand Up @@ -228,10 +229,8 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
) -> Tuple[list[RequestType], bool, int]:
requests: list[RequestType] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
Expand Down
3 changes: 2 additions & 1 deletion serve/mlc_serve/model/dummy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DecodeRequest,
KVCache,
PrefillRequest,
RequestType,
SequenceId,
TextGenerationResult,
)
Expand Down Expand Up @@ -97,7 +98,7 @@ def get_max_new_tokens(self) -> int:
class DummyTextGenerator:
def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest]],
requests: list[RequestType],
kv_cache: DummyCache,
) -> list[TextGenerationResult]:
result = []
Expand Down
7 changes: 4 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path
import structlog
from typing import List, Union
from typing import List

from .base import get_model_artifact_config
from .paged_cache_manager import CacheManager
Expand All @@ -13,6 +13,7 @@
ModelModule,
PrefillRequest,
EvalMultiQueryRequest,
RequestType,
TextGenerationResult,
TextGenerator,
)
Expand All @@ -26,9 +27,9 @@ def __init__(self, model: TextGenerator):

def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
requests: List[RequestType],
kv_cache,
) -> list[TextGenerationResult]:
) -> List[TextGenerationResult]:
prefill_requests = []
decode_requests = []
multi_query_decode_requests = []
Expand Down

0 comments on commit e58b7d3

Please sign in to comment.