diff --git a/poc/actual/main.py b/poc/actual/main.py index c801524..ca6d6e0 100644 --- a/poc/actual/main.py +++ b/poc/actual/main.py @@ -49,19 +49,13 @@ async def get_latency(async_func, *args, **kwargs): @torch.no_grad() async def main(): print("Main started") - # manager_cls = Manager - manager_cls = ManagerNonSI - # verifier_cls = VerifierSlow - verifier_cls = Verifier + verifier_cls = VerifierSlow drafter_cls = Drafter - # verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct" - verifier_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct" + verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct" verifier_load_in_8bit = True verifier_dtype = torch.float16 num_verifiers = 2 - max_new_tokens = 200 - # max_new_tokens = 100 - # max_new_tokens = 20 + max_new_tokens = 50 verifier_device_map_filename = ( "device_map_meta-llama_Meta-Llama-3.1-70B-Instruct_8bit_on_3A40_custom.json" ) @@ -122,15 +116,12 @@ async def run_our_implementation(): return manager.tok_ids latencies = defaultdict(list) - # TODO: Remove this - prompts = prompts[:1] ts = time.time() with logfire.span("Run with args:\n{args}\nat {ts}", args=locals(), ts=ts): for prompt in tqdm(prompts, desc="Prompts"): with logfire.span("Prompt: {prompt}", prompt=prompt): tok_ids = encode(prompt, verifier_name) - # for manager_cls in [Manager, ManagerSI, ManagerNonSI]: - for manager_cls in [ManagerSI]: + for manager_cls in [Manager, ManagerSI, ManagerNonSI]: with logfire.span("{cls_name}", cls_name=manager_cls): manager = manager_cls( draft_queue=draft_queue,