Skip to content

Commit

Permalink
wip avoid overlapped requests via requested
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 16, 2024
1 parent 39f2802 commit d7dac4f
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions poc/dsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,28 +169,34 @@ def __init__(
dtype=torch.int64,
)
self.id_to_mask: Dict[UUID, torch.Tensor] = {}
self.requested_verify = torch.full_like(self.draft_tok_ids, False, dtype=torch.bool)
self.requested_draft = self.requested_verify.clone()
self.pubsub = PubSub()
print("Manager: Initialized with PubSub")
self.timestamp_preemption = 0 # Initialize with 0

async def _send(self, request: Request, queue: asyncio.Queue[Request]) -> None:
self.id_to_mask[request.id] = request.get_mask(
seq_len=self.seq_len, is_draft=queue == self.draft_queue
)
requested = self.requested_verify if queue == self.verify_queue else self.requested_draft
if requested[0, self.id_to_mask[request.id]].all():
print(
f"Manager: Won't send request {request.id} because it covers already requested positions."
)
return
requested[0, self.id_to_mask[request.id]] = True
print(
f"Manager: Enqueuing request {request.id} to {'draft' if queue == self.draft_queue else 'verify'} queue"
)
await queue.put(request)

def _reset(self) -> None:
print("Manager: Resetting draft_scores, draft_tok_ids, and id_to_mask")
self._empty(self.draft_scores)
self._empty(self.draft_tok_ids)
self.draft_scores.fill_(-1)
self.draft_tok_ids.fill_(-1)
self.id_to_mask.clear()

@staticmethod
def _empty(t: torch.Tensor) -> None:
t.fill_(-1)
self.requested_verify.fill_(False)
self.requested_draft.fill_(False)

# @staticmethod
# async def _empty_queue(queue: asyncio.Queue) -> None:
Expand All @@ -204,24 +210,20 @@ def _empty(t: torch.Tensor) -> None:
async def preempt_all(self) -> None:
"""
Broadcasts a preemption message to all workers and clears the request queues.
Updates the last preemption timestamp to the current time.
Assumptions:
- This method has exclusive access to modify the timestamp_preemption.
Guarantees:
- All workers will be notified of the preemption.
- All request queues will be emptied.
- The timestamp_preemption will be updated.
"""
print("Manager: Preempting all workers")
# Update the last preemption timestamp
self.timestamp_preemption = time.time()
# Send preempt message to workers
print("Manager: Sending preempt message to workers")
await self.pubsub.publish(Preemption.create())
print(
f"Manager: Preempt message sent to workers at {self.timestamp_preemption}"
f"Manager: Preempt message sent to workers"
)
# # Clear the queues
# print("Manager: Clearing queues")
Expand All @@ -243,19 +245,15 @@ async def run(self) -> None:
print(
f"Manager: Received response {response}. Will process if not outdated."
)
if response.request_timestamp <= self.timestamp_preemption:
print(f"Manager: Dropping outdated response {response.id}")
if response.id not in self.id_to_mask:
print(
f"Manager: Response {response.id} is not in id_to_mask. Dropping."
)
self.response_queue.task_done()
continue
print(
f"Manager: Processing response {response.id}. (It is not outdated.)"
)
# if response.id not in self.id_to_mask:
# print(
# f"Manager: Response {response.id} is not in id_to_mask. Dropping."
# )
# self.response_queue.task_done()
# continue
mask: torch.Tensor = self.id_to_mask.pop(response.id)
# print(f"Manager: Popped mask {mask} for response {response.id}")
if response.is_draft:
Expand All @@ -272,8 +270,9 @@ async def run(self) -> None:
tok_ids, any_rejected = self.rejection_sampler(response, mask)
tok_ids_padded = torch.full_like(self.tok_ids[0, mask], -1)
tok_ids_padded[: len(tok_ids)] = tok_ids
print(f"Manager: {tok_ids_padded=}. Before assignment: {self.tok_ids[0, mask]=}")
self.tok_ids[0, mask] = tok_ids_padded
print(f"Manager: Updated tok_ids with response {response.id}.")
print(f"Manager: Updated tok_ids with response {response.id}. After assignment: {self.tok_ids[0, mask]=}")
self.response_queue.task_done()
break
if any_rejected:
Expand Down

0 comments on commit d7dac4f

Please sign in to comment.