Skip to content

Commit

Permalink
handle special tokens via shortened masks
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 22, 2024
1 parent 2362141 commit 192dd5b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 45 deletions.
88 changes: 43 additions & 45 deletions poc/actual/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from enum import Enum
from poc.actual.utils import get_shorter_mask
import torch
from poc.actual.pubsub import PubSub
from poc.actual.event import Preemption, Request, Response
Expand All @@ -8,6 +10,26 @@
from uuid import UUID


class GenerationCompletedReason(Enum):
    """
    Reason for generation completion.
    """

    # Generation stopped because the token budget was exhausted.
    MAX_TOKENS = "max_tokens"
    # Generation stopped because a special (e.g. end-of-sequence) token was produced.
    SPECIAL_TOKEN = "special_token"


class GenerationCompleted(Exception):
    """
    Raised when a generation is completed.

    Carries the `GenerationCompletedReason` explaining why generation stopped.
    """

    def __init__(self, reason: "GenerationCompletedReason"):
        # Forward the reason to Exception so `args` and `repr()` carry it too
        # (the original left `args` empty, making logged exceptions opaque).
        super().__init__(reason)
        self.reason = reason

    def __str__(self):
        return f"GenerationCompleted(reason={self.reason})"


class Manager:
"""
Manages the overall system, handling requests, responses, and preemptions.
Expand Down Expand Up @@ -180,11 +202,12 @@ async def run(self) -> None:
)
mask: torch.Tensor = self.id_to_mask.pop(response.id)
if response.is_draft:
self.draft_scores[0, mask] = response.scores
# scores_padded = torch.full_like(self.draft_scores[0, mask], -1)
n = response.scores.shape[1]
# scores_padded[:n] = response.scores
# self.draft_scores[0, mask] = scores_padded
try:
self.draft_scores[:, mask] = response.scores
except RuntimeError as e:
print(f"{self.__class__.__name__}: Error updating draft scores: {e}")
mask = get_shorter_mask(mask_1d=mask, n=n)
self.draft_tok_ids[0, mask] = response.tok_ids[
0, -n:
]
Expand Down Expand Up @@ -276,6 +299,14 @@ def get_tok_ids_with_drafts(self) -> torch.Tensor:
ret[nonempty_mask] = self.tok_ids[nonempty_mask]
return ret

def _crop_length(self, n: int) -> None:
    """
    Crops the `tok_ids`, `draft_tok_ids`, and `draft_scores` to the first `n` tokens.

    All three buffers are sliced along the token axis (dim 1); `draft_scores`
    keeps its trailing vocabulary dimension intact.
    """
    for attr in ("tok_ids", "draft_tok_ids"):
        setattr(self, attr, getattr(self, attr)[:, :n])
    self.draft_scores = self.draft_scores[:, :n, :]


class ManagerSI(Manager):
async def run(self) -> None:
Expand All @@ -288,7 +319,7 @@ async def run(self) -> None:
print(
f"{self.__class__.__name__}: number of empty tok_ids: {(self.tok_ids == -1).sum()}"
)
print(f"{self.__class__.__name__}: {self.tok_ids=}")
print(f"{self.__class__.__name__}: tok_ids:\n{self.tok_ids}")
# 1. Draft
mask_draft_tok_ids_to_draft = (self.tok_ids == -1) & (
self.draft_tok_ids == -1
Expand All @@ -303,15 +334,16 @@ async def run(self) -> None:
)
print(f"{self.__class__.__name__}: Waiting for draft response")
response_draft: Response = await self.response_queue.get()
n = response_draft.scores.shape[1]
print(
f"{self.__class__.__name__}: Received draft response {response_draft}."
)
mask: torch.Tensor = self.id_to_mask.pop(response_draft.id)
self.draft_scores[0, mask] = response_draft.scores
# scores_padded = torch.full_like(self.draft_scores[0, mask], -1)
n = response_draft.scores.shape[1]
# scores_padded[:n] = response_draft.scores
# self.draft_scores[0, mask] = scores_padded
try:
self.draft_scores[:, mask] = response_draft.scores
except RuntimeError as e:
print(f"{self.__class__.__name__}: Error updating draft scores: {e}")
mask = get_shorter_mask(mask_1d=mask, n=n)
self.draft_tok_ids[0, mask] = response_draft.tok_ids[0, -n:]
print(
f"{self.__class__.__name__}: Updated draft tok_ids and scores with response {response_draft.id}. After the update, the draft tok_ids are\n{self.draft_tok_ids}"
Expand Down Expand Up @@ -359,41 +391,7 @@ async def run(self) -> None:
print(
f"{self.__class__.__name__}: number of empty tok_ids: {(self.tok_ids == -1).sum()}"
)
print(f"{self.__class__.__name__}: {self.tok_ids=}")
# # 1. Draft
# mask_draft_tok_ids_to_draft = (self.tok_ids == -1) & (
# self.draft_tok_ids == -1
# )
# curr_lookahead: int = min(
# self.lookahead, mask_draft_tok_ids_to_draft.sum() - 1
# )
# if curr_lookahead > 0:
# await self._send(
# Request.create(self.get_tok_ids_with_drafts(), curr_lookahead),
# self.draft_queue,
# )
# print(f"{self.__class__.__name__}: Waiting for draft response")
# response_draft: Response = await self.response_queue.get()
# print(
# f"{self.__class__.__name__}: Received draft response {response_draft}."
# )
# mask: torch.Tensor = self.id_to_mask.pop(response_draft.id)
# self.draft_scores[0, mask] = response_draft.scores
# self.draft_tok_ids[0, mask] = response_draft.tok_ids[
# 0, -response_draft.scores.shape[1] :
# ]
# print(
# f"{self.__class__.__name__}: Updated draft tok_ids and scores with response {response_draft.id}. After the update, the draft tok_ids are {self.draft_tok_ids}"
# )
# self.response_queue.task_done()
# 2. Verify
# mask_draft_tok_ids_to_verify = (self.tok_ids == -1) & (
# self.draft_tok_ids != -1
# )
# print(
# f"{self.__class__.__name__}: number of draft tokens waiting for verification: {mask_draft_tok_ids_to_verify.sum()}"
# )
# n = 1 + max(0, mask_draft_tok_ids_to_verify.sum())
print(f"{self.__class__.__name__}: tok_ids:\n{self.tok_ids}")
await self._send(
Request.create(self.get_tok_ids_with_drafts(), n=1),
self.verify_queue,
Expand Down
58 changes: 58 additions & 0 deletions poc/actual/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,61 @@ def get_queues(
draft_queue = asyncio.Queue(maxsize=1)
response_queue = asyncio.Queue()
return verify_queue, draft_queue, response_queue


def right_pad_like(array_to_pad: torch.Tensor, like: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Right pads the `array_to_pad` with -1s to be like the `like` tensor.

    Args:
        array_to_pad: Tensor to pad. Must match `like` in every dimension
            except `dim`, where it may be shorter than (or equal to) `like`.
        like: Reference tensor providing the target shape, dtype, and device.
        dim: Dimension along which to pad. Negative values are supported and
            interpreted as in standard indexing (e.g. -1 is the last dim).

    Returns:
        A new tensor shaped like `like`, containing `array_to_pad` at the
        start of dimension `dim` and -1 in every position past it.

    Raises:
        ValueError: If the ranks differ, `dim` is out of range, any non-`dim`
            dimension mismatches, or `array_to_pad` is longer than `like`
            along `dim`.
    """
    if array_to_pad.dim() != like.dim():
        raise ValueError(f"Tensors must have the same number of dimensions. Got {array_to_pad.dim()} and {like.dim()}")

    # Normalize negative dims (e.g. -1 -> last dimension) before validating.
    if dim < 0:
        dim += array_to_pad.dim()
    if dim < 0 or dim >= array_to_pad.dim():
        raise ValueError(f"Invalid dimension {dim}. Must be between 0 and {array_to_pad.dim() - 1}")

    for i in range(array_to_pad.dim()):
        if i != dim and array_to_pad.shape[i] != like.shape[i]:
            raise ValueError(f"Shapes must match in all dimensions except {dim}. Mismatch at dimension {i}: {array_to_pad.shape[i]} vs {like.shape[i]}")

    if array_to_pad.shape[dim] > like.shape[dim]:
        raise ValueError(f"The 'dim' dimension of array_to_pad ({array_to_pad.shape[dim]}) cannot be larger than that of 'like' ({like.shape[dim]})")

    n = array_to_pad.shape[dim]
    # full_like inherits dtype/device from `like`; -1 marks padded positions.
    padded = torch.full_like(like, -1)
    # narrow() yields an in-place view of the first `n` slots along `dim`.
    padded.narrow(dim, 0, n).copy_(array_to_pad)
    return padded


def get_shorter_mask(mask_1d: torch.Tensor, n: int) -> torch.Tensor:
    """
    Keep only the first `n` True values of a 1-dimensional boolean mask.

    The mask is a 1-dimensional boolean tensor of False values, with a
    subsequence of True values. This function "shortens" the mask by keeping
    the first `n` True values and replacing the rest with False values.
    Raises an error if `n` is strictly greater than the number of True values
    in the mask.

    Returns `mask_1d` itself (no copy) when `n` equals the number of True
    values; otherwise returns a new tensor, leaving the input unmodified.

    Examples:
        - mask [False, True, True, True, False], n=0 -> [False, False, False, False, False].
        - mask [False, True, True, True, False], n=1 -> [False, True, False, False, False].
        - mask [False, True, True, True, False], n=2 -> [False, True, True, False, False].
        - mask [False, True, True, True, False], n=3 -> [False, True, True, True, False] (unchanged).
        - mask [False, True, True, True, False], n=4 -> raises ValueError.
        - mask [True, True, True, False, False], n=2 -> [True, True, False, False, False].
    """
    true_count = int(mask_1d.sum().item())  # Number of True values in the mask
    if n > true_count:
        raise ValueError(f"Cannot shorten the mask to keep {n} True values because there are only {true_count} True values in the mask")
    if n == true_count:
        # Nothing to clear: every True value is kept.
        return mask_1d

    new_mask = mask_1d.clone()  # Clone the original mask to avoid modifying it directly
    if n == 0:
        new_mask[:] = False
        return new_mask

    # Positions of the True values; clear everything from the (n+1)-th onward.
    # (n < true_count is guaranteed here, so true_indices[n] exists.)
    true_indices = new_mask.nonzero(as_tuple=True)[0]
    new_mask[true_indices[n]:] = False
    return new_mask

0 comments on commit 192dd5b

Please sign in to comment.