From cd78e1097be995a80ccedbfe6dae22c0c01c11db Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Fri, 13 Sep 2024 18:12:39 +0000 Subject: [PATCH] logging --- poc/dsi.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/poc/dsi.py b/poc/dsi.py index a549dd3..ef64a01 100644 --- a/poc/dsi.py +++ b/poc/dsi.py @@ -231,8 +231,10 @@ async def preempt_all(self) -> None: async def run(self) -> None: print("Manager: Starting run") - print(f"Manager: tok_ids: {self.tok_ids}") + print(f"Manager: prompt's tok_ids.shape: {self.tok_ids.shape}") + print(f"Manager: prompt's tok_ids: {self.tok_ids}") while (self.tok_ids == -1).any(): # On init or rejection + print(f"Manager: sequence's length: {(self.tok_ids != -1).sum()}") print(f"Manager: number of empty tok_ids: {(self.tok_ids == -1).sum()}") print("Manager: Resetting (on init or rejection)") self._reset() @@ -252,11 +254,11 @@ async def run(self) -> None: mask_draft_tok_ids_waiting = (self.tok_ids == -1) & ( self.draft_tok_ids != -1 ) - print(f"Manager: number of draft tok_ids waiting for verification: {mask_draft_tok_ids_waiting.sum()}") + n = 1 + max(0, mask_draft_tok_ids_waiting.sum()) await self._send(Request.create(self.tok_ids, n=n), self.verify_queue) print( - f"Manager: Sent verify request with tok_ids={self.tok_ids} and n={n}" + f"Manager: Sent verify request with n={n}, tok_ids.shape={self.tok_ids.shape}, and tok_ids={self.tok_ids}" ) print("Manager: Waiting for response") response: Response = await self.response_queue.get() @@ -480,6 +482,7 @@ async def run(self) -> None: print( f"{self.__class__.__name__}: Processing request with ID {request.id}" ) + print(f"{self.__class__.__name__}: Request {request.id} has {request.tok_ids.shape=}") current_task = asyncio.create_task(self.perform_task(request)) done, pending = await asyncio.wait( {current_task, get_preempt}, return_when=asyncio.FIRST_COMPLETED @@ -825,16 +828,16 @@ def generate(model_name: str, prompt: str, max_new_tokens: int) -> str: ### Response: """ - asyncio.run( - run( - verifier_name=verifier_name, - drafter_name=drafter_name, - vocab_size=vocab_size, - lookahead=lookahead, - prompt=prompt, - max_new_tokens=max_new_tokens, - ) - ) - # print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens)) + # asyncio.run( + # run( + # verifier_name=verifier_name, + # drafter_name=drafter_name, + # vocab_size=vocab_size, + # lookahead=lookahead, + # prompt=prompt, + # max_new_tokens=max_new_tokens, + # ) + # ) + print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens)) print("Script completed")