Skip to content

Commit

Permalink
fix rejection_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 13, 2024
1 parent e5dc4fb commit 592a02c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
26 changes: 16 additions & 10 deletions poc/dsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,14 @@ async def run(self) -> None:
self._reset()
curr_lookahead: int = min(self.lookahead, (self.tok_ids == -1).sum() - 1)
print(f"Manager: The current lookahead is {curr_lookahead}")
await self._send(
Request.create(self.get_tok_ids_with_drafts(), curr_lookahead),
self.draft_queue,
)
print(
f"Manager: Sent draft request with tok_ids={self.get_tok_ids_with_drafts()} and n={curr_lookahead}"
)
if curr_lookahead > 0:
await self._send(
Request.create(self.get_tok_ids_with_drafts(), curr_lookahead),
self.draft_queue,
)
print(
f"Manager: Sent draft request with n={curr_lookahead}, tok_ids.shape={self.get_tok_ids_with_drafts().shape}, and tok_ids={self.get_tok_ids_with_drafts()}"
)
while (
self.tok_ids == -1
).any(): # continue on acceptance; stop on rejection
Expand All @@ -256,9 +257,9 @@ async def run(self) -> None:
)

n = 1 + max(0, mask_draft_tok_ids_waiting.sum())
await self._send(Request.create(self.tok_ids, n=n), self.verify_queue)
await self._send(Request.create(self.get_tok_ids_with_drafts(), n=n), self.verify_queue)
print(
f"Manager: Sent verify request with n={n}, tok_ids.shape={self.tok_ids.shape}, and tok_ids={self.tok_ids}"
f"Manager: Sent verify request with n={n}, tok_ids.shape={self.get_tok_ids_with_drafts().shape}, and tok_ids={self.get_tok_ids_with_drafts()}"
)
print("Manager: Waiting for response")
response: Response = await self.response_queue.get()
Expand Down Expand Up @@ -296,7 +297,7 @@ async def run(self) -> None:
tok_ids, any_rejected = self.rejection_sampler(response, mask)
self.tok_ids[0, mask] = tok_ids
print(
f"Manager: Updated tok_ids with response {response.id} to {tok_ids}"
f"Manager: Updated tok_ids with response {response.id} (accepted {tok_ids.shape[1]} tokens)"
)
if any_rejected:
print(f"Manager: Rejected response {response.id}")
Expand All @@ -323,6 +324,11 @@ def rejection_sampler(
print(
f"Manager: Comparing draft tok_ids {draft_tok_ids} with accepted tok_ids {tok_ids_accepted}:\n{draft_tok_ids == tok_ids_accepted}"
)
if any_rejected:
idx_first_rejected = (draft_tok_ids != tok_ids_accepted).nonzero()[0].item()
print(f"Manager: First rejected token is at index {idx_first_rejected}. Accepting the first {idx_first_rejected} tokens.")
tok_ids_accepted = tok_ids_accepted[:idx_first_rejected+1]
print(f"Manager: New accepted tok_ids: {tok_ids_accepted}")
return tok_ids_accepted, any_rejected

def get_tok_ids_with_drafts(self) -> torch.Tensor:
Expand Down
27 changes: 27 additions & 0 deletions poc/transformers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,33 @@
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rejection sampling"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0])"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.tensor([ True, False, True, True, True, False]).nonzero()[0]"
]
}
],
"metadata": {
Expand Down

0 comments on commit 592a02c

Please sign in to comment.