fix mypy
Valery Chernov committed Dec 25, 2023
1 parent c185936 commit 7dbdbc4
Showing 3 changed files with 11 additions and 11 deletions.
serve/mlc_serve/api/handler.py (2 changes: 1 addition & 1 deletion)

@@ -217,7 +217,7 @@ async def collect_result_stream(
     finish_reasons = [None] * num_sequences
     num_prompt_tokens = 0
     num_generated_tokens = [0 for _ in range(num_sequences)]
-    logprob_infos = [[] for _ in range(num_sequences)]
+    logprob_infos = [[] for _ in range(num_sequences)]  # type: ignore
     async for res in result_generator:
         # TODO: verify that the request cancellation happens after this returns
         if res.error:
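Note on the handler.py change: `[[] for _ in range(num_sequences)]` gives mypy nothing to infer an element type from, so the commit silences the error with `# type: ignore`. A minimal sketch of the alternative, an explicit variable annotation, is below; the `Optional[Tuple]` element type is an assumption for illustration, not taken from the repository.

```python
from typing import List, Optional, Tuple

num_sequences = 2  # placeholder value for the sketch

# An explicit annotation tells mypy the element type of the inner lists,
# which removes the need for "# type: ignore". The element type is assumed.
logprob_infos: List[List[Optional[Tuple]]] = [[] for _ in range(num_sequences)]
```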
serve/mlc_serve/engine/staging_engine.py (8 changes: 4 additions & 4 deletions)

@@ -5,7 +5,7 @@
 import multiprocessing
 import queue
 from threading import Lock
-from typing import Callable, Tuple, List
+from typing import Callable, Tuple, List, Dict
 from collections import defaultdict

 import structlog
@@ -39,13 +39,13 @@
 LOG = structlog.stdlib.get_logger(__name__)


-def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, List[Tuple]]:
+def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, Dict[str, float]]:
     """Detokenize logprob information"""
     if logprob_info is None:
         return None
     (res, res_logprob), top_tokens = logprob_info
     top_tokens = list(top_tokens)
-    count = {}
+    count = {}  # type: Dict[str, int]
     logprob_dict = {}
     # dedup duplicates
     # Todo: Make sure decode can generate different tokens
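Note on the staging_engine.py change: the new `# type: Dict[str, int]` is a type comment, which constrains the empty dict without touching runtime behaviour. A minimal sketch of the equivalent inline annotation (available on Python 3.6+) is shown below; the `logprob_dict` annotation is added here only to mirror the function's new `Dict[str, float]` return type and is an assumption.

```python
from typing import Dict

# Equivalent to the commit's type comment, written as a PEP 526 annotation.
count: Dict[str, int] = {}

# Assumed annotation, mirroring the new Tuple[Tuple, Dict[str, float]] return type.
logprob_dict: Dict[str, float] = {}
```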
@@ -285,7 +285,7 @@ def step(self) -> InferenceStepResult:
     0,
     delta=delta,
     num_generated_tokens=(
-        len(state.token_ids) - state.prompt_len
+        len(state.prompt_token_ids) - state.prompt_len
     ),
     finish_reason=seq_output.finish_reason,
     logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info),
serve/mlc_serve/model/paged_cache_model.py (12 changes: 6 additions & 6 deletions)

@@ -535,9 +535,9 @@ def generate(
 for index, (sequence_id, new_token) in enumerate(
     zip(sequence_ids, next_tokens)
 ):
-    if not new_token in requests[i].sampling_params.appeared_tokens_freq:
-        requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
-    requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
+    if not new_token in requests[index].sampling_params.appeared_tokens_freq:
+        requests[index].sampling_params.appeared_tokens_freq[new_token] = 0
+    requests[index].sampling_params.appeared_tokens_freq[new_token] += 1
     if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
         for seq_id in range(num_sequences[index]):
             outputs.append(
@@ -579,9 +579,9 @@ def generate(

 if maybe_new_token is not None:
     new_token = maybe_new_token[0]
-    if not new_token in requests[i].sampling_params.appeared_tokens_freq:
-        requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
-    requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
+    if not new_token in requests[index].sampling_params.appeared_tokens_freq:
+        requests[index].sampling_params.appeared_tokens_freq[new_token] = 0
+    requests[index].sampling_params.appeared_tokens_freq[new_token] += 1
     if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
         for seq_id in range(num_sequences[index]):
             outputs.append(
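Note on the paged_cache_model.py changes: both hunks replace a stale index variable `i` with `index`, the name actually bound by `enumerate(...)`. The init-then-increment pattern on `appeared_tokens_freq` could also be collapsed into one line with `dict.get`; the sketch below is illustrative only and assumes the frequency map is a plain `Dict[int, int]`.

```python
from typing import Dict

def record_appearance(appeared_tokens_freq: Dict[int, int], new_token: int) -> None:
    # One-step equivalent of the diff's "initialize to 0, then increment".
    appeared_tokens_freq[new_token] = appeared_tokens_freq.get(new_token, 0) + 1

freq: Dict[int, int] = {}
record_appearance(freq, 42)
record_appearance(freq, 42)
assert freq[42] == 2
```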
