diff --git a/poc/dsi.py b/poc/dsi.py index bb02f12..afe703f 100644 --- a/poc/dsi.py +++ b/poc/dsi.py @@ -933,8 +933,7 @@ def decode(tok_ids: torch.Tensor, tokernizer_name: str) -> str: async def main(): print("Script started") print_gpu_memory() - # verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct" - verifier_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct" + verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct" drafter_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct" verifier_dtype: torch.dtype = torch.float16 drafter_dtype: torch.dtype = torch.float16