diff --git a/poc/dsi.py b/poc/dsi.py
index f08dc0c..b99490a 100644
--- a/poc/dsi.py
+++ b/poc/dsi.py
@@ -234,13 +234,13 @@ async def run(self) -> None:
         print(f"Manager: prompt's tok_ids.shape: {self.tok_ids.shape}")
         print(f"Manager: prompt's tok_ids: {self.tok_ids}")
         while (self.tok_ids == -1).any():  # On init or rejection
-            print(f"Manager: sequence's length: {(self.tok_ids != -1).sum()}")
-            print(f"Manager: number of empty tok_ids: {(self.tok_ids == -1).sum()}")
             print("Manager: Resetting (on init or rejection)")
             self._reset()
             while (
                 self.tok_ids == -1
             ).any():  # continue on acceptance; stop on rejection
+                print(f"Manager: sequence's length: {(self.tok_ids != -1).sum()}")
+                print(f"Manager: number of empty tok_ids: {(self.tok_ids == -1).sum()}")
                 # Select n based on the number of draft tokens waiting for verification
                 mask_draft_tok_ids_waiting = (self.tok_ids == -1) & (
                     self.draft_tok_ids != -1
@@ -256,7 +256,9 @@ async def run(self) -> None:
                 print(
                     f"Manager: Sent verify request with n={n}, tok_ids.shape={self.get_tok_ids_with_drafts().shape}, and tok_ids={self.get_tok_ids_with_drafts()}"
                 )
-                curr_lookahead: int = min(self.lookahead, (self.tok_ids == -1).sum() - 1)
+                curr_lookahead: int = min(
+                    self.lookahead, (self.tok_ids == -1).sum() - 1
+                )
                 print(f"Manager: The current lookahead is {curr_lookahead}")
                 if curr_lookahead > 0:
                     self._send(
@@ -278,12 +280,12 @@ async def run(self) -> None:
                 print(
                     f"Manager: Processing response {response.id}. (It is not outdated.)"
                 )
-                if response.id not in self.id_to_mask:
-                    print(
-                        f"Manager: Response {response.id} is not in id_to_mask. Dropping."
-                    )
-                    self.response_queue.task_done()
-                    continue
+                # if response.id not in self.id_to_mask:
+                #     print(
+                #         f"Manager: Response {response.id} is not in id_to_mask. Dropping."
+                #     )
+                #     self.response_queue.task_done()
+                #     continue
                 mask: torch.Tensor = self.id_to_mask.pop(response.id)
                 print(f"Manager: Popped mask {mask} for response {response.id}")
                 if response.is_draft:
@@ -292,26 +294,27 @@ async def run(self) -> None:
                     )
                     self.draft_scores[0, mask] = response.scores
                     print(f"Manager: Updated draft scores with response {response.id}")
+                    # self.draft_tok_ids[0, mask] = response.tok_ids[
+                    #     0, -response.scores.shape[1] :
+                    # ]
                     self.draft_tok_ids[0, mask] = response.tok_ids[
-                        0, -response.scores.shape[1] :
+                        0, : response.tok_ids.shape[1]
                     ]
-                    print(f"Manager: Updated draft tok_ids with response {response.id}")
+                    print(
+                        f"Manager: Updated draft tok_ids with response {response.id}. After the update, the draft tok_ids are {self.draft_tok_ids}"
+                    )
                 else:
                     tok_ids: torch.Tensor
                     any_rejected: bool
                     tok_ids, any_rejected = self.rejection_sampler(response, mask)
-                    print(
-                        f"Manager: Updated tok_ids with response {response.id}"
-                    )
+                    print(f"Manager: Updated tok_ids with response {response.id}")
                     tok_ids_padded = torch.full_like(self.tok_ids[0, mask], -1)
                     tok_ids_padded[: len(tok_ids)] = tok_ids
                     print(
                         f"Manager: padded the tok_ids with -1s to the right to match the masked tok_ids: {tok_ids_padded=}"
                     )
                     self.tok_ids[0, mask] = tok_ids_padded
-                    print(
-                        f"Manager: Token ids after assignment: {self.tok_ids}"
-                    )
+                    print(f"Manager: Token ids after assignment: {self.tok_ids}")
                     if any_rejected:
                         print(f"Manager: Rejected response {response.id}")
                         self.response_queue.task_done()
@@ -422,7 +425,11 @@ async def load_model(
         cache_dir = os.environ["TRANSFORMERS_CACHE"]
         print(f"{self.__class__.__name__}: Loading model {name} with {device_map=}")
         self.model = AutoModelForCausalLM.from_pretrained(
-            name, torch_dtype=dtype, device_map=device_map, cache_dir=cache_dir, load_in_8bit=load_in_8bit
+            name,
+            torch_dtype=dtype,
+            device_map=device_map,
+            cache_dir=cache_dir,
+            load_in_8bit=load_in_8bit,
         )
         self.model.eval()
         # if device != cpu:
@@ -605,7 +612,6 @@ async def perform_task(self, request: Request) -> Response:
         print(f"{self.__class__.__name__}: Getting scores for task {request.id}")
         device = next(self.model.parameters()).device
         tok_ids = request.tok_ids.to(device)
-        loop = asyncio.get_running_loop()
         # Run in executor (i.e., separate thread) to avoid blocking the event loop
         scores: torch.Tensor
         tok_ids: torch.Tensor
@@ -859,7 +865,11 @@ async def run(


 def generate(
-    model_name: str, dtype: torch.dtype, load_in_8bit: bool, tok_ids: torch.Tensor, max_new_tokens: int
+    model_name: str,
+    dtype: torch.dtype,
+    load_in_8bit: bool,
+    tok_ids: torch.Tensor,
+    max_new_tokens: int,
 ) -> str:
     setup_hf_cache()
     print(f"Loading tokenizer for {model_name}")