Commit 9a29650
Change detokenization to using token ids.
zxybazh committed Jan 25, 2024
1 parent 86f6fa1 commit 9a29650
Showing 4 changed files with 20 additions and 15 deletions.
1 change: 1 addition & 0 deletions serve/mlc_serve/engine/base.py
@@ -21,6 +21,7 @@ class RawLogprobsInfo:
     current_logprob: float
     top_tokens: Optional[np.array]
     top_logprobs: Optional[np.array]
+    previous_tokens: Optional[List[int]]


 # TODO(@sunggg): consider transition to something like Pydantic
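For orientation, here is a hedged reconstruction of the dataclass after this commit; any field above the visible hunk (such as the current token id) is assumed rather than shown by the diff:

    from dataclasses import dataclass
    from typing import List, Optional
    import numpy as np

    @dataclass
    class RawLogprobsInfo:
        current_token: int                    # assumed: declared above the visible hunk
        current_logprob: float
        top_tokens: Optional[np.array]
        top_logprobs: Optional[np.array]
        previous_tokens: Optional[List[int]]  # new: ids generated before this token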
22 changes: 8 additions & 14 deletions serve/mlc_serve/engine/engine_common.py
@@ -151,19 +151,13 @@ def logprob_detokenize(
         logprob_info.top_logprobs is not None
     ):
         top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs))
-        count: Dict[str, int] = {}
-        # dedup duplicates
-        # Todo: Make sure decode can generate different tokens
-        for top_token, _ in top_tokens:
-            detokenized = tokenizer.decode(top_token)
-            if detokenized in count:
-                count[detokenized] += 1
-            else:
-                count[detokenized] = 1
+        if logprob_info.previous_tokens is None:
+            logprob_info.previous_tokens = []
         for top_token, top_logprob in top_tokens:
-            detokenized = tokenizer.decode(top_token)
-            if count[detokenized] != 1:
-                detokenized = f"{detokenized}_{top_token}"
+            detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1]
+            LOG.info(f"detokenized: {detokenized}")
             top_logprobs.append(TopLogprobs(
                 token=detokenized,
                 logprob=float(top_logprob),
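The deleted count/suffix workaround existed because per-id decode() can map distinct ids to identical display strings. A hedged illustration, assuming a SentencePiece-style HuggingFace tokenizer (the checkpoint name below is a stand-in for demonstration, not something this repo pins):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    with_space = tok.convert_tokens_to_ids("▁The")  # "The" after whitespace
    bare = tok.convert_tokens_to_ids("The")         # "The" mid-word

    # Per-id decode() can collapse both to the same string, so a top-k list
    # could show two visually identical entries:
    print(tok.decode([with_space]), "|", tok.decode([bare]))  # The | The

    # Token-level pieces stay distinct, so no dedup suffix is needed:
    print(tok.convert_ids_to_tokens([with_space, bare]))  # ['▁The', 'The']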
@@ -184,14 +178,14 @@

 def logprobs_detokenize(
     tokenizer: TokenizerP,
-    logprobs_info: List[Optional[RawLogprobsInfo]],
+    logprob_info: List[Optional[RawLogprobsInfo]],
 ) -> Optional[List[Optional[LogprobsContent]]]:
-    if logprobs_info is None:
+    if logprob_info is None:
         return None

     res: List[Optional[LogprobsContent]] = []
-    for logprob_info in logprobs_info:
-        res.append(logprob_detokenize(tokenizer, logprob_info))
+    for info in logprob_info:
+        res.append(logprob_detokenize(tokenizer, info))

     check_all = all([x is None for x in res])
     if check_all:
1 change: 1 addition & 0 deletions serve/mlc_serve/model/model_common.py
@@ -75,6 +75,7 @@ def fetch_raw_logprob_infos(
             current_logprob=res_logprob,
             top_tokens=top_tokens,
             top_logprobs=top_logprobs,
+            previous_tokens=None
         ))
     else:
         logprob_infos.append(None)
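Note that previous_tokens is only stubbed out with None here. Reading across the rest of this diff, the sampler appears to have no access to the per-sequence token history, so the model runner attaches the real ids later via attach_detokenization_info (see tvm_model.py below).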
11 changes: 10 additions & 1 deletion serve/mlc_serve/model/tvm_model.py
@@ -22,6 +22,7 @@
     PROMPT_SEQEUNCE_INDEX,
     get_prompt_sequence_id,
     MLCServeEngineConfig,
+    RawLogprobsInfo,
 )
 from ..engine.model_module import (
     DecodeRequest,
@@ -84,6 +85,11 @@ def get_tvm_model(config, dev):

     return load_disco_module(config.model_artifact_path, lib_path, config.num_shards)

+def attach_detokenization_info(logprob_info: RawLogprobsInfo, token_ids: List[int]):
+    if logprob_info is None:
+        return None
+    logprob_info.previous_tokens = token_ids
+    return logprob_info

 def _prepare_inputs(
     sequence_ids,
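The helper is a small in-place setter. A minimal runnable sketch of its contract, using a stripped-down stand-in for RawLogprobsInfo (only the field this commit touches):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class _Info:  # stand-in for RawLogprobsInfo
        previous_tokens: Optional[List[int]] = None

    def attach_detokenization_info(logprob_info, token_ids: List[int]):
        if logprob_info is None:  # logprobs may not have been requested
            return None
        logprob_info.previous_tokens = token_ids
        return logprob_info

    info = attach_detokenization_info(_Info(), [1, 15043])  # illustrative ids
    assert info.previous_tokens == [1, 15043]
    assert attach_detokenization_info(None, [42]) is None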
@@ -326,6 +332,7 @@ def generate(

         try:
             next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size)
+            current_ids = list(input_ids.numpy())
             assert next_tokens is not None
             outputs = []
             for i, (sequence_id, new_token) in enumerate(
@@ -341,9 +348,10 @@
                             sequence_id=SequenceId(sequence_id.request_id, seq_id),
                             generated_tokens=[new_token],
                             error=None,
-                            logprob_info=[logprob_infos[i]],
+                            logprob_info=[attach_detokenization_info(logprob_infos[i], current_ids)],
                         )
                     )
+                    current_ids.append(new_token)
                 else:
                     outputs.append(
                         TextGenerationResult(
@@ -353,6 +361,7 @@
                             logprob_info=[logprob_infos[i]],
                         )
                     )
+                    current_ids.append(new_token)

             return outputs
         except RuntimeError:
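The current_ids bookkeeping pairs each emitted logprob record with the ids that precede its token, then grows the list. A hedged, single-sequence simplification of that loop shape (toy ids, no TVM involved; the real code shares one list across the batch):

    from typing import List

    def detok_contexts(prompt_ids: List[int], sampled: List[int]) -> List[List[int]]:
        # Mirrors the pattern above: snapshot the context for this step's
        # logprob record, then append the new token for the next step.
        current_ids = list(prompt_ids)
        contexts = []
        for new_token in sampled:
            contexts.append(list(current_ids))
            current_ids.append(new_token)
        return contexts

    assert detok_contexts([1, 15043], [3186, 29991]) == [[1, 15043], [1, 15043, 3186]]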
