diff --git a/poc/dsi.py b/poc/dsi.py index ef64a01..7b51b83 100644 --- a/poc/dsi.py +++ b/poc/dsi.py @@ -828,16 +828,16 @@ def generate(model_name: str, prompt: str, max_new_tokens: int) -> str: ### Response: """ - # asyncio.run( - # run( - # verifier_name=verifier_name, - # drafter_name=drafter_name, - # vocab_size=vocab_size, - # lookahead=lookahead, - # prompt=prompt, - # max_new_tokens=max_new_tokens, - # ) - # ) - print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens)) + asyncio.run( + run( + verifier_name=verifier_name, + drafter_name=drafter_name, + vocab_size=vocab_size, + lookahead=lookahead, + prompt=prompt, + max_new_tokens=max_new_tokens, + ) + ) + # print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens)) print("Script completed")