Skip to content

Commit

Permalink
fix mypy
Browse files · Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Dec 29, 2023
1 parent a1047a2 commit a822016
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/staging_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
update_sequence,
logprob_detokenize
)
from .model_module import ModelModule, TokenizerModule, Tokenizer
from .model_module import ModelModule, TokenizerModule
from .staging_engine_worker import (
AddRequestsCommand,
CancelRequestCommand,
Expand Down Expand Up @@ -263,7 +263,7 @@ def step(self) -> InferenceStepResult:
0,
delta=delta,
num_generated_tokens=(
len(state.token_ids) - state.prompt_len
len(state.prompt_token_ids) - state.prompt_len
),
finish_reason=seq_output.finish_reason,
logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info),
Expand Down
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def step(self) -> InferenceStepResult:
0,
delta=delta,
num_generated_tokens=(
len(state.token_ids) - state.prompt_len
len(state.prompt_token_ids) - state.prompt_len
),
logprob_info=res.logprob_info
),
Expand Down
12 changes: 6 additions & 6 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,9 @@ def generate(
for i, (sequence_id, new_token) in enumerate(
zip(sequence_ids, next_tokens)
):
if not new_token in requests[i].sampling_params.appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if not new_token in requests[index].sampling_params.appeared_tokens_freq:
requests[index].sampling_params.appeared_tokens_freq[new_token] = 0
requests[index].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
for seq_id in range(num_sequences[i]):
outputs.append(
Expand Down Expand Up @@ -581,9 +581,9 @@ def generate(

if maybe_new_token is not None:
new_token = maybe_new_token[0]
if not new_token in requests[i].sampling_params.appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if not new_token in requests[index].sampling_params.appeared_tokens_freq:
requests[index].sampling_params.appeared_tokens_freq[new_token] = 0
requests[index].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
for seq_id in range(num_sequences[i]):
outputs.append(
Expand Down

0 comments on commit a822016

Please sign in to comment.