From 592a02c721297ef2580e06a34329b4a0d986f7b7 Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Fri, 13 Sep 2024 18:35:00 +0000 Subject: [PATCH] fix rejection_sampler --- poc/dsi.py | 26 ++++++++++++++++---------- poc/transformers.ipynb | 27 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/poc/dsi.py b/poc/dsi.py index 7b51b83..b15fa3c 100644 --- a/poc/dsi.py +++ b/poc/dsi.py @@ -240,13 +240,14 @@ async def run(self) -> None: self._reset() curr_lookahead: int = min(self.lookahead, (self.tok_ids == -1).sum() - 1) print(f"Manager: The current lookahead is {curr_lookahead}") - await self._send( - Request.create(self.get_tok_ids_with_drafts(), curr_lookahead), - self.draft_queue, - ) - print( - f"Manager: Sent draft request with tok_ids={self.get_tok_ids_with_drafts()} and n={curr_lookahead}" - ) + if curr_lookahead > 0: + await self._send( + Request.create(self.get_tok_ids_with_drafts(), curr_lookahead), + self.draft_queue, + ) + print( + f"Manager: Sent draft request with n={curr_lookahead}, tok_ids.shape={self.get_tok_ids_with_drafts().shape}, and tok_ids={self.get_tok_ids_with_drafts()}" + ) while ( self.tok_ids == -1 ).any(): # continue on acceptance; stop on rejection @@ -256,9 +257,9 @@ async def run(self) -> None: ) n = 1 + max(0, mask_draft_tok_ids_waiting.sum()) - await self._send(Request.create(self.tok_ids, n=n), self.verify_queue) + await self._send(Request.create(self.get_tok_ids_with_drafts(), n=n), self.verify_queue) print( - f"Manager: Sent verify request with n={n}, tok_ids.shape={self.tok_ids.shape}, and tok_ids={self.tok_ids}" + f"Manager: Sent verify request with n={n}, tok_ids.shape={self.get_tok_ids_with_drafts().shape}, and tok_ids={self.get_tok_ids_with_drafts()}" ) print("Manager: Waiting for response") response: Response = await self.response_queue.get() @@ -296,7 +297,7 @@ async def run(self) -> None: tok_ids, any_rejected = self.rejection_sampler(response, mask) self.tok_ids[0, mask] = tok_ids print( - f"Manager: Updated tok_ids with response {response.id} to {tok_ids}" + f"Manager: Updated tok_ids with response {response.id} (accepted {tok_ids.shape[1]} tokens)" ) if any_rejected: print(f"Manager: Rejected response {response.id}") @@ -323,6 +324,11 @@ def rejection_sampler( print( f"Manager: Comparing draft tok_ids {draft_tok_ids} with accepted tok_ids {tok_ids_accepted}:\n{draft_tok_ids == tok_ids_accepted}" ) + if any_rejected: + idx_first_rejected = (draft_tok_ids != tok_ids_accepted).nonzero()[0].item() + print(f"Manager: First rejected token is at index {idx_first_rejected}. Accepting the first {idx_first_rejected} tokens.") + tok_ids_accepted = tok_ids_accepted[:idx_first_rejected+1] + print(f"Manager: New accepted tok_ids: {tok_ids_accepted}") return tok_ids_accepted, any_rejected def get_tok_ids_with_drafts(self) -> torch.Tensor: diff --git a/poc/transformers.ipynb b/poc/transformers.ipynb index 28b0c40..e766ff9 100644 --- a/poc/transformers.ipynb +++ b/poc/transformers.ipynb @@ -408,6 +408,33 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rejection sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0])" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.tensor([ True, False, True, True, True, False]).nonzero()[0]" + ] } ], "metadata": {