Skip to content

Commit

Permalink
fix current lookahead
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 16, 2024
1 parent d7dac4f commit b734405
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions poc/dsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,22 +283,23 @@ async def run(self) -> None:
@torch.no_grad()
async def send_requests(self) -> None:
# Select n based on the number of draft tokens waiting for verification
mask_draft_tok_ids_waiting = (self.tok_ids == -1) & (
mask_draft_tok_ids_to_verify = (self.tok_ids == -1) & (
self.draft_tok_ids != -1
)
print(
f"Manager: number of draft tokens waiting for verification: {mask_draft_tok_ids_waiting.sum()}"
f"Manager: number of draft tokens waiting for verification: {mask_draft_tok_ids_to_verify.sum()}"
)
n = 1 + max(0, mask_draft_tok_ids_waiting.sum())
n = 1 + max(0, mask_draft_tok_ids_to_verify.sum())
await self._send(
Request.create(self.get_tok_ids_with_drafts(), n=n),
self.verify_queue,
)
print(
f"Manager: Sent verify request with n={n} and tok_ids={self.get_tok_ids_with_drafts()}"
)
mask_draft_tok_ids_to_draft = (self.tok_ids == -1) & (self.draft_tok_ids == -1)
curr_lookahead: int = min(
self.lookahead, (self.tok_ids == -1).sum() - 1
self.lookahead, mask_draft_tok_ids_to_draft.sum() - 1
)
if curr_lookahead > 0:
await self._send(
Expand Down

0 comments on commit b734405

Please sign in to comment.