WIP manager: fix draft tok ids assignment
keyboardAnt committed Sep 15, 2024
1 parent a238572 commit e6d5472
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions poc/dsi.py
@@ -234,13 +234,13 @@ async def run(self) -> None:
print(f"Manager: prompt's tok_ids.shape: {self.tok_ids.shape}")
print(f"Manager: prompt's tok_ids: {self.tok_ids}")
while (self.tok_ids == -1).any(): # On init or rejection
print(f"Manager: sequence's length: {(self.tok_ids != -1).sum()}")
print(f"Manager: number of empty tok_ids: {(self.tok_ids == -1).sum()}")
print("Manager: Resetting (on init or rejection)")
self._reset()
while (
self.tok_ids == -1
).any(): # continue on acceptance; stop on rejection
print(f"Manager: sequence's length: {(self.tok_ids != -1).sum()}")
print(f"Manager: number of empty tok_ids: {(self.tok_ids == -1).sum()}")
# Select n based on the number of draft tokens waiting for verification
mask_draft_tok_ids_waiting = (self.tok_ids == -1) & (
self.draft_tok_ids != -1
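
Note: the manager marks unfilled positions with a -1 sentinel, and mask_draft_tok_ids_waiting selects positions that are still empty in tok_ids but already populated by the drafter. A minimal standalone sketch of this masking, with made-up tensors:

import torch

# -1 marks positions that have not been verified/filled yet.
tok_ids = torch.tensor([[5, 9, -1, -1, -1]])        # verified sequence so far
draft_tok_ids = torch.tensor([[-1, -1, 7, 3, -1]])  # speculative draft tokens

# Empty in tok_ids but already drafted -> waiting for verification.
mask_draft_tok_ids_waiting = (tok_ids == -1) & (draft_tok_ids != -1)
n = int(mask_draft_tok_ids_waiting.sum())
print(n)  # 2 draft tokens are waiting to be verified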
@@ -256,7 +256,9 @@ async def run(self) -> None:
                 print(
                     f"Manager: Sent verify request with n={n}, tok_ids.shape={self.get_tok_ids_with_drafts().shape}, and tok_ids={self.get_tok_ids_with_drafts()}"
                 )
-                curr_lookahead: int = min(self.lookahead, (self.tok_ids == -1).sum() - 1)
+                curr_lookahead: int = min(
+                    self.lookahead, (self.tok_ids == -1).sum() - 1
+                )
                 print(f"Manager: The current lookahead is {curr_lookahead}")
                 if curr_lookahead > 0:
                     self._send(
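
Note: curr_lookahead caps how far the drafter speculates by the number of still-empty slots minus one, presumably reserving the last empty position for a token that must come from the verifier. A toy illustration of the arithmetic:

import torch

lookahead = 5                                 # configured maximum draft length
tok_ids = torch.tensor([[5, 9, -1, -1, -1]])  # three positions still empty
curr_lookahead = min(lookahead, int((tok_ids == -1).sum()) - 1)
print(curr_lookahead)  # 2: drafting never fills the final empty slot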
@@ -278,12 +280,12 @@ async def run(self) -> None:
                 print(
                     f"Manager: Processing response {response.id}. (It is not outdated.)"
                 )
-                if response.id not in self.id_to_mask:
-                    print(
-                        f"Manager: Response {response.id} is not in id_to_mask. Dropping."
-                    )
-                    self.response_queue.task_done()
-                    continue
+                # if response.id not in self.id_to_mask:
+                #     print(
+                #         f"Manager: Response {response.id} is not in id_to_mask. Dropping."
+                #     )
+                #     self.response_queue.task_done()
+                #     continue
                 mask: torch.Tensor = self.id_to_mask.pop(response.id)
                 print(f"Manager: Popped mask {mask} for response {response.id}")
                 if response.is_draft:
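
Note: with the drop-guard commented out, self.id_to_mask.pop(response.id) now raises KeyError whenever a response id has no registered mask (e.g., if it was already consumed); presumably this WIP state relies on the earlier outdated-response check. A small illustration of the difference, with hypothetical values:

id_to_mask = {}  # hypothetical: maps response id -> boolean position mask

response_id = 42
# Unguarded, as in the commit: raises KeyError for an unknown id.
# mask = id_to_mask.pop(response_id)
# A defensive alternative would be:
mask = id_to_mask.pop(response_id, None)
if mask is None:
    print(f"Response {response_id} has no registered mask; dropping")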
@@ -292,26 +294,27 @@ async def run(self) -> None:
                     )
                     self.draft_scores[0, mask] = response.scores
                     print(f"Manager: Updated draft scores with response {response.id}")
+                    # self.draft_tok_ids[0, mask] = response.tok_ids[
+                    #     0, -response.scores.shape[1] :
+                    # ]
                     self.draft_tok_ids[0, mask] = response.tok_ids[
-                        0, -response.scores.shape[1] :
+                        0, : response.tok_ids.shape[1]
                     ]
-                    print(f"Manager: Updated draft tok_ids with response {response.id}")
+                    print(
+                        f"Manager: Updated draft tok_ids with response {response.id}. After the update, the draft tok_ids are {self.draft_tok_ids}"
+                    )
                 else:
                     tok_ids: torch.Tensor
                     any_rejected: bool
                     tok_ids, any_rejected = self.rejection_sampler(response, mask)
-                    print(
-                        f"Manager: Updated tok_ids with response {response.id}"
-                    )
+                    print(f"Manager: Updated tok_ids with response {response.id}")
                     tok_ids_padded = torch.full_like(self.tok_ids[0, mask], -1)
                     tok_ids_padded[: len(tok_ids)] = tok_ids
                     print(
                         f"Manager: padded the tok_ids with -1s to the right to match the masked tok_ids: {tok_ids_padded=}"
                     )
                     self.tok_ids[0, mask] = tok_ids_padded
-                    print(
-                        f"Manager: Token ids after assignment: {self.tok_ids}"
-                    )
+                    print(f"Manager: Token ids after assignment: {self.tok_ids}")
                     if any_rejected:
                         print(f"Manager: Rejected response {response.id}")
                         self.response_queue.task_done()
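
Note: the core of this commit swaps the slice used to copy the drafter's tokens. The old right-anchored slice took the last response.scores.shape[1] tokens, while the new left-anchored slice takes the first response.tok_ids.shape[1] tokens (i.e., the whole row); the two coincide only when the response carries exactly the scored tokens. A runnable toy version of the assignment, with made-up shapes:

import torch

draft_tok_ids = torch.full((1, 6), -1)
mask = torch.tensor([False, False, True, True, False, False])

# Hypothetical response: the drafter returned exactly mask.sum() tokens.
response_tok_ids = torch.tensor([[11, 12]])
response_scores = torch.rand(1, 2, 32)  # (batch, tokens, vocab), made up

old_rhs = response_tok_ids[0, -response_scores.shape[1]:]   # commented-out slice
new_rhs = response_tok_ids[0, :response_tok_ids.shape[1]]   # slice after the fix

draft_tok_ids[0, mask] = new_rhs
print(draft_tok_ids)  # tensor([[-1, -1, 11, 12, -1, -1]])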
@@ -422,7 +425,11 @@ async def load_model(
         cache_dir = os.environ["TRANSFORMERS_CACHE"]
         print(f"{self.__class__.__name__}: Loading model {name} with {device_map=}")
         self.model = AutoModelForCausalLM.from_pretrained(
-            name, torch_dtype=dtype, device_map=device_map, cache_dir=cache_dir, load_in_8bit=load_in_8bit
+            name,
+            torch_dtype=dtype,
+            device_map=device_map,
+            cache_dir=cache_dir,
+            load_in_8bit=load_in_8bit,
         )
         self.model.eval()
         # if device != cpu:
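
Note: load_in_8bit is forwarded straight to AutoModelForCausalLM.from_pretrained, which requires the bitsandbytes package when set to True; recent transformers releases route this through quantization_config=BitsAndBytesConfig(load_in_8bit=True) instead. A minimal usage sketch (the model name is an example, not from this file):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,
    device_map="auto",
)
model.eval()  # inference only, as in load_model above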
@@ -605,7 +612,6 @@ async def perform_task(self, request: Request) -> Response:
print(f"{self.__class__.__name__}: Getting scores for task {request.id}")
device = next(self.model.parameters()).device
tok_ids = request.tok_ids.to(device)
loop = asyncio.get_running_loop()
# Run in executor (i.e., separate thread) to avoid blocking the event loop
scores: torch.Tensor
tok_ids: torch.Tensor
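
Note: the removed line suggests the running loop is now obtained inline (or elsewhere). A generic sketch of the run-in-executor pattern the comment describes, not this file's exact code:

import asyncio

def blocking_forward():
    # Stand-in for the model's (blocking) forward pass.
    return "scores"

async def perform():
    loop = asyncio.get_running_loop()
    # Offload the blocking call to the default thread pool so the
    # event loop keeps servicing other coroutines.
    return await loop.run_in_executor(None, blocking_forward)

print(asyncio.run(perform()))  # scores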
@@ -859,7 +865,11 @@ async def run(


 def generate(
-    model_name: str, dtype: torch.dtype, load_in_8bit: bool, tok_ids: torch.Tensor, max_new_tokens: int
+    model_name: str,
+    dtype: torch.dtype,
+    load_in_8bit: bool,
+    tok_ids: torch.Tensor,
+    max_new_tokens: int,
 ) -> str:
     setup_hf_cache()
     print(f"Loading tokenizer for {model_name}")