diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py
index 594f107a3..e6fca8666 100644
--- a/outlines/text/generate/sequence.py
+++ b/outlines/text/generate/sequence.py
@@ -40,6 +40,38 @@ def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
     def postprocess_completions(self, completions: List[str]) -> List[str]:
         return completions
 
+    def _next_token_logits(
+        self,
+        num_prompt_tokens: int,
+        token_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+    ) -> torch.FloatTensor:
+        logits = self.model(token_ids, attention_mask)
+        logits = self.create_proposal(token_ids[:, num_prompt_tokens:], logits)
+        return logits
+
+    def next_token_logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor:
+        """Compute the next-token logits for a given prompt.
+
+        Parameters
+        ----------
+        prompt
+            The input prompt.
+
+        Returns
+        -------
+        A tensor of shape `(batch_size, vocab_size)` that contains the logits
+        (unnormalised log probabilities) for the next token.
+
+        """
+        token_ids, attention_mask = self.model.tokenizer.encode(prompt)
+        num_prompt_tokens = token_ids.shape[-1]
+        token_ids = token_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+        return self._next_token_logits(
+            num_prompt_tokens, token_ids, attention_mask
+        ).squeeze()
+
     def step(
         self,
         rng: torch.Generator,
@@ -78,9 +110,8 @@ def step(
 
         """
         num_input_dims = token_ids.ndim
-        probs = self.model(token_ids, attention_mask)
-        probs = self.create_proposal(token_ids[:, num_prompt_tokens:], probs)
-        probs = torch.nn.functional.softmax(probs, dim=-1)
+        logits = self._next_token_logits(num_prompt_tokens, token_ids, attention_mask)
+        probs = torch.nn.functional.softmax(logits, dim=-1)
 
         # Sample `samples`-many new tokens
         next_token_ids = vectorized_random_choice(rng, probs, samples)
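
A minimal usage sketch of the new public `next_token_logits` method. The model setup below (`models.transformers`, `generate.continuation`) is assumed from the library's examples and is not part of this diff; any concrete `Sequence` subclass should work the same way.

    import outlines.models as models
    import outlines.text.generate as generate

    model = models.transformers("gpt2")      # assumed constructor, not in this diff
    sequence = generate.continuation(model)  # assumed Sequence subclass

    # Logits for the token that would follow the prompt. Note the trailing
    # `.squeeze()` in the diff: for a single prompt the batch dimension is
    # dropped, so this is a 1-D tensor of shape (vocab_size,).
    logits = sequence.next_token_logits("The quick brown fox")

Since the method returns raw (unnormalised) logits rather than probabilities, callers who want a distribution still apply `torch.nn.functional.softmax(logits, dim=-1)` themselves, exactly as `step` now does through the shared `_next_token_logits` helper.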