From d7dac4ff7eaaa304cf10e890c063e1305820483c Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Mon, 16 Sep 2024 20:52:25 +0000 Subject: [PATCH] wip avoid overlapped requests via `requested` --- poc/dsi.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/poc/dsi.py b/poc/dsi.py index c067c93..3c92d66 100644 --- a/poc/dsi.py +++ b/poc/dsi.py @@ -169,14 +169,22 @@ def __init__( dtype=torch.int64, ) self.id_to_mask: Dict[UUID, torch.Tensor] = {} + self.requested_verify = torch.full_like(self.draft_tok_ids, False, dtype=torch.bool) + self.requested_draft = self.requested_verify.clone() self.pubsub = PubSub() print("Manager: Initialized with PubSub") - self.timestamp_preemption = 0 # Initialize with 0 async def _send(self, request: Request, queue: asyncio.Queue[Request]) -> None: self.id_to_mask[request.id] = request.get_mask( seq_len=self.seq_len, is_draft=queue == self.draft_queue ) + requested = self.requested_verify if queue == self.verify_queue else self.requested_draft + if requested[0, self.id_to_mask[request.id]].all(): + print( + f"Manager: Won't send request {request.id} because it covers already requested positions." + ) + return + requested[0, self.id_to_mask[request.id]] = True print( f"Manager: Enqueuing request {request.id} to {'draft' if queue == self.draft_queue else 'verify'} queue" ) @@ -184,13 +192,11 @@ async def _send(self, request: Request, queue: asyncio.Queue[Request]) -> None: def _reset(self) -> None: print("Manager: Resetting draft_scores, draft_tok_ids, and id_to_mask") - self._empty(self.draft_scores) - self._empty(self.draft_tok_ids) + self.draft_scores.fill_(-1) + self.draft_tok_ids.fill_(-1) self.id_to_mask.clear() - - @staticmethod - def _empty(t: torch.Tensor) -> None: - t.fill_(-1) + self.requested_verify.fill_(False) + self.requested_draft.fill_(False) # @staticmethod # async def _empty_queue(queue: asyncio.Queue) -> None: @@ -204,7 +210,6 @@ def _empty(t: torch.Tensor) -> None: async def preempt_all(self) -> None: """ Broadcasts a preemption message to all workers and clears the request queues. - Updates the last preemption timestamp to the current time. Assumptions: - This method has exclusive access to modify the timestamp_preemption. @@ -212,16 +217,13 @@ async def preempt_all(self) -> None: Guarantees: - All workers will be notified of the preemption. - All request queues will be emptied. - - The timestamp_preemption will be updated. """ print("Manager: Preempting all workers") - # Update the last preemption timestamp - self.timestamp_preemption = time.time() # Send preempt message to workers print("Manager: Sending preempt message to workers") await self.pubsub.publish(Preemption.create()) print( - f"Manager: Preempt message sent to workers at {self.timestamp_preemption}" + f"Manager: Preempt message sent to workers" ) # # Clear the queues # print("Manager: Clearing queues") @@ -243,19 +245,15 @@ async def run(self) -> None: print( f"Manager: Received response {response}. Will process if not outdated." ) - if response.request_timestamp <= self.timestamp_preemption: - print(f"Manager: Dropping outdated response {response.id}") + if response.id not in self.id_to_mask: + print( + f"Manager: Response {response.id} is not in id_to_mask. Dropping." + ) self.response_queue.task_done() continue print( f"Manager: Processing response {response.id}. (It is not outdated.)" ) - # if response.id not in self.id_to_mask: - # print( - # f"Manager: Response {response.id} is not in id_to_mask. Dropping." - # ) - # self.response_queue.task_done() - # continue mask: torch.Tensor = self.id_to_mask.pop(response.id) # print(f"Manager: Popped mask {mask} for response {response.id}") if response.is_draft: @@ -272,8 +270,9 @@ async def run(self) -> None: tok_ids, any_rejected = self.rejection_sampler(response, mask) tok_ids_padded = torch.full_like(self.tok_ids[0, mask], -1) tok_ids_padded[: len(tok_ids)] = tok_ids + print(f"Manager: {tok_ids_padded=}. Before assignment: {self.tok_ids[0, mask]=}") self.tok_ids[0, mask] = tok_ids_padded - print(f"Manager: Updated tok_ids with response {response.id}.") + print(f"Manager: Updated tok_ids with response {response.id}. After assignment: {self.tok_ids[0, mask]=}") self.response_queue.task_done() break if any_rejected: