Skip to content

Commit

Permalink
logging
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 13, 2024
1 parent 236c7bc commit cd78e10
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions poc/dsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ async def preempt_all(self) -> None:

async def run(self) -> None:
print("Manager: Starting run")
print(f"Manager: tok_ids: {self.tok_ids}")
print(f"Manager: prompt's tok_ids.shape: {self.tok_ids.shape}")
print(f"Manager: prompt's tok_ids: {self.tok_ids}")
while (self.tok_ids == -1).any(): # On init or rejection
print(f"Manager: sequence's length: {(self.tok_ids != -1).sum()}")
print(f"Manager: number of empty tok_ids: {(self.tok_ids == -1).sum()}")
print("Manager: Resetting (on init or rejection)")
self._reset()
Expand All @@ -252,11 +254,11 @@ async def run(self) -> None:
mask_draft_tok_ids_waiting = (self.tok_ids == -1) & (
self.draft_tok_ids != -1
)
print(f"Manager: number of draft tok_ids waiting for verification: {mask_draft_tok_ids_waiting.sum()}")

n = 1 + max(0, mask_draft_tok_ids_waiting.sum())
await self._send(Request.create(self.tok_ids, n=n), self.verify_queue)
print(
f"Manager: Sent verify request with tok_ids={self.tok_ids} and n={n}"
f"Manager: Sent verify request with n={n}, tok_ids.shape={self.tok_ids.shape}, and tok_ids={self.tok_ids}"
)
print("Manager: Waiting for response")
response: Response = await self.response_queue.get()
Expand Down Expand Up @@ -480,6 +482,7 @@ async def run(self) -> None:
print(
f"{self.__class__.__name__}: Processing request with ID {request.id}"
)
print(f"{self.__class__.__name__}: Request {request.id} has {request.tok_ids.shape=}")
current_task = asyncio.create_task(self.perform_task(request))
done, pending = await asyncio.wait(
{current_task, get_preempt}, return_when=asyncio.FIRST_COMPLETED
Expand Down Expand Up @@ -825,16 +828,16 @@ def generate(model_name: str, prompt: str, max_new_tokens: int) -> str:
### Response:
"""

asyncio.run(
run(
verifier_name=verifier_name,
drafter_name=drafter_name,
vocab_size=vocab_size,
lookahead=lookahead,
prompt=prompt,
max_new_tokens=max_new_tokens,
)
)
# print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens))
# asyncio.run(
# run(
# verifier_name=verifier_name,
# drafter_name=drafter_name,
# vocab_size=vocab_size,
# lookahead=lookahead,
# prompt=prompt,
# max_new_tokens=max_new_tokens,
# )
# )
print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens))

print("Script completed")

0 comments on commit cd78e10

Please sign in to comment.