From e7143bf969eca2dacc2e56929fbafed6014a5bc8 Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Wed, 16 Aug 2023 11:48:46 -0400
Subject: [PATCH 1/5] Add Sequence.logits()

---
 outlines/text/generate/sequence.py | 35 +++++++++++++++++++++++++++---
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py
index 594f107a3..8f54f4e7a 100644
--- a/outlines/text/generate/sequence.py
+++ b/outlines/text/generate/sequence.py
@@ -40,6 +40,36 @@ def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
     def postprocess_completions(self, completions: List[str]) -> List[str]:
         return completions
 
+    def _logits(
+        self,
+        num_prompt_tokens: int,
+        token_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+    ) -> torch.FloatTensor:
+        logits = self.model(token_ids, attention_mask)
+        logits = self.create_proposal(token_ids[:, num_prompt_tokens:], logits)
+        return logits
+
+    def logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor:
+        """Compute the next-token logits for a given prompt.
+
+        Parameters
+        ----------
+        prompt
+            The input prompt.
+
+        Returns
+        -------
+        An array containing the logits (unnormalised log probabilities)
+        for the next token generation.
+
+        """
+        token_ids, attention_mask = self.model.tokenizer.encode(prompt)
+        num_prompt_tokens = token_ids.shape[-1]
+        token_ids = token_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+        return self._logits(num_prompt_tokens, token_ids, attention_mask).squeeze()
+
     def step(
         self,
         rng: torch.Generator,
@@ -78,9 +108,8 @@ def step(
 
         """
         num_input_dims = token_ids.ndim
-        probs = self.model(token_ids, attention_mask)
-        probs = self.create_proposal(token_ids[:, num_prompt_tokens:], probs)
-        probs = torch.nn.functional.softmax(probs, dim=-1)
+        logits = self._logits(num_prompt_tokens, token_ids, attention_mask)
+        probs = torch.nn.functional.softmax(logits, dim=-1)
 
         # Sample `samples`-many new tokens
         next_token_ids = vectorized_random_choice(rng, probs, samples)

From fdb1fdf04138f260992ac53552dc7a8775bba18c Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Thu, 17 Aug 2023 11:38:54 -0400
Subject: [PATCH 2/5] Next token logits method

---
 outlines/text/generate/sequence.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py
index 8f54f4e7a..5a60b4edc 100644
--- a/outlines/text/generate/sequence.py
+++ b/outlines/text/generate/sequence.py
@@ -40,7 +40,7 @@ def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
     def postprocess_completions(self, completions: List[str]) -> List[str]:
         return completions
 
-    def _logits(
+    def _next_token_logits(
         self,
         num_prompt_tokens: int,
         token_ids: torch.LongTensor,
@@ -50,7 +50,7 @@ def _logits(
         logits = self.create_proposal(token_ids[:, num_prompt_tokens:], logits)
         return logits
 
-    def logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor:
+    def next_token_logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor:
        """Compute the next-token logits for a given prompt.
 
         Parameters
@@ -60,15 +60,15 @@ def logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor:
 
         Returns
         -------
-        An array containing the logits (unnormalised log probabilities)
-        for the next token generation.
+        An array of shape `(batch_size, vocab_size)` containing the logits 
+        (unnormalised log probabilities) for the next token generation.
 
""" token_ids, attention_mask = self.model.tokenizer.encode(prompt) num_prompt_tokens = token_ids.shape[-1] token_ids = token_ids.to(self.device) attention_mask = attention_mask.to(self.device) - return self._logits(num_prompt_tokens, token_ids, attention_mask).squeeze() + return self._next_token_logits(num_prompt_tokens, token_ids, attention_mask).squeeze() def step( self, @@ -108,7 +108,7 @@ def step( """ num_input_dims = token_ids.ndim - logits = self._logits(num_prompt_tokens, token_ids, attention_mask) + logits = self._next_token_logits(num_prompt_tokens, token_ids, attention_mask) probs = torch.nn.functional.softmax(logits, dim=-1) # Sample `samples`-many new tokens From bed6edeac085865aa9acfe14020beb5a2bb40bdb Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 17 Aug 2023 11:42:05 -0400 Subject: [PATCH 3/5] Reformant --- outlines/text/generate/sequence.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py index 5a60b4edc..e6fca8666 100644 --- a/outlines/text/generate/sequence.py +++ b/outlines/text/generate/sequence.py @@ -60,7 +60,7 @@ def next_token_logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor: Returns ------- - An array of shape `(batch_size, vocab_size)` containing the logits + An array of shape `(batch_size, vocab_size)` containing the logits (unnormalised log probabilities) for the next token generation. """ @@ -68,7 +68,9 @@ def next_token_logits(self, prompt: Union[str, List[str]]) -> torch.FloatTensor: num_prompt_tokens = token_ids.shape[-1] token_ids = token_ids.to(self.device) attention_mask = attention_mask.to(self.device) - return self._next_token_logits(num_prompt_tokens, token_ids, attention_mask).squeeze() + return self._next_token_logits( + num_prompt_tokens, token_ids, attention_mask + ).squeeze() def step( self, From 8d851195e34eef79296e56ab6288cb008859afd5 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 17 Aug 2023 15:07:12 -0400 Subject: [PATCH 4/5] Remove prob from step --- outlines/text/generate/sequence.py | 13 +++++----- tests/text/generate/test_sequence.py | 37 ++++++++++++++-------------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py index e6fca8666..aac0c4ba9 100644 --- a/outlines/text/generate/sequence.py +++ b/outlines/text/generate/sequence.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch @@ -79,7 +79,7 @@ def step( token_ids: torch.LongTensor, attention_mask: torch.LongTensor, samples: int = 1, - ) -> Tuple[torch.LongTensor, torch.FloatTensor]: + ) -> torch.LongTensor: """Generate one or several tokens that complete the input sequence. The sampling step consists in using a model to generate next-token @@ -103,8 +103,7 @@ def step( ------- A tuple with an array of shape `new_batch_shape + (num_tokens+1,)`that contains the completed sequences (input token ids and generated token - ids) and an array of shape `new_batch_shape + (vocab_size,)` that - contains the next token probabilities. + ids). `new_batch_shape` is computed by removing dimensions of size one in `(samples,) + batch_shape`. 
@@ -135,7 +134,7 @@ def step(
         token_ids = torch.atleast_2d(token_ids.squeeze())
         probs = torch.atleast_2d(probs.squeeze())
 
-        return token_ids, probs
+        return token_ids
 
     def expand_attention_mask(
         self, attention_mask: torch.LongTensor
@@ -234,7 +233,7 @@ def __call__(
         num_prompt_tokens = token_ids.shape[-1]
 
         if samples > 1:
-            token_ids, _ = self.step(
+            token_ids = self.step(
                 rng, num_prompt_tokens, token_ids, attention_mask, samples
             )
             is_finished = self.is_finished(token_ids)
@@ -252,7 +251,7 @@ def __call__(
             if torch.all(is_finished) or num_generated_tokens == self.max_tokens:
                 break
 
-            updated_token_ids, _ = self.step(
+            updated_token_ids = self.step(
                 rng,
                 num_prompt_tokens,
                 token_ids[~is_finished],
diff --git a/tests/text/generate/test_sequence.py b/tests/text/generate/test_sequence.py
index 8e5f3dd15..6d207992b 100644
--- a/tests/text/generate/test_sequence.py
+++ b/tests/text/generate/test_sequence.py
@@ -116,9 +116,11 @@ def test_sequence_step():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2]])
-    token_ids, probs = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
+    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1]]))
-    assert probs.shape == (1, 4)
+
+    logits_out = sequence._next_token_logits(2, input_ids, torch.ones((1, 2)))
+    assert logits_out.shape == (1, 4)
 
 
 def test_sequence_step_batch():
@@ -131,9 +133,11 @@ def test_sequence_step_batch():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2], [3, 4]])
-    token_ids, probs = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
+    token_ids = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1], [3, 4, 1]]))
-    assert probs.shape == (2, 4)
+
+    logits_out = sequence._next_token_logits(2, input_ids, torch.ones((2, 2)))
+    assert logits_out.shape == (2, 4)
 
 
 def test_sequence_step_sample():
@@ -145,9 +149,8 @@ def test_sequence_step_sample():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2]])
-    token_ids, probs = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
+    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1], [1, 2, 1], [1, 2, 1]]))
-    assert probs.shape == (3, 4)
 
 
 def test_sequence_step_sample_batch():
@@ -159,7 +162,7 @@ def test_sequence_step_sample_batch():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2, 1], [3, 4, 1]])
-    token_ids, probs = sequence.step(rng, 3, input_ids, torch.ones((2, 3)), samples=3)
+    token_ids = sequence.step(rng, 3, input_ids, torch.ones((2, 3)), samples=3)
     assert torch.equal(
         token_ids,
         torch.tensor(
@@ -170,7 +173,6 @@ def test_sequence_step_sample_batch():
             ]
         ),
     )
-    assert probs.shape == (3, 2, 4)
 
 
 def test_sequence_step_loop():
@@ -183,25 +185,22 @@ def test_sequence_step_loop():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2]])
-    token_ids, _ = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
-    token_ids, probs = sequence.step(rng, 2, token_ids, torch.ones((1, 3)))
+    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
+    token_ids = sequence.step(rng, 2, token_ids, torch.ones((1, 3)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1, 1]]))
-    assert probs.shape == (1, 4)
 
     input_ids = torch.tensor([[1, 2], [3, 4]])
-    token_ids, _ = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
-    token_ids, probs = sequence.step(rng, 2, token_ids, torch.ones((2, 3)))
+    token_ids = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
+    token_ids = sequence.step(rng, 2, token_ids, torch.ones((2, 3)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1, 1], [3, 4, 1, 1]]))
-    assert probs.shape == (2, 4)
 
     # The number of samples becomes the batch size at the next iteration.
     input_ids = torch.tensor([[1, 2]])
-    token_ids, _ = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
-    token_ids, probs = sequence.step(rng, 2, token_ids, torch.ones((3, 3)))
+    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
+    token_ids = sequence.step(rng, 2, token_ids, torch.ones((3, 3)))
     assert torch.equal(
         token_ids, torch.tensor([[1, 2, 1, 1], [1, 2, 1, 1], [1, 2, 1, 1]])
     )
-    assert probs.shape == (3, 4)
 
 
 def test_sequence_step_loop_general():
@@ -213,8 +212,8 @@ def test_sequence_step_loop_general():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2, 1], [3, 4, 1]])
-    token_ids, _ = sequence.step(rng, 3, input_ids, torch.ones((1, 3)), samples=3)
-    result, _ = sequence.step(rng, 3, token_ids, torch.ones((3, 4)))
+    token_ids = sequence.step(rng, 3, input_ids, torch.ones((1, 3)), samples=3)
+    result = sequence.step(rng, 3, token_ids, torch.ones((3, 4)))
     assert result.shape == (3, 2, 5)
     assert torch.equal(
         result,

From 7bf195e13b9e7281d1b28461b41e8a8defcdc65a Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Thu, 17 Aug 2023 15:57:31 -0400
Subject: [PATCH 5/5] Revert "Remove prob from step"

This reverts commit 8d851195e34eef79296e56ab6288cb008859afd5.
---
 outlines/text/generate/sequence.py   | 13 +++++-----
 tests/text/generate/test_sequence.py | 37 ++++++++++++++--------------
 2 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py
index aac0c4ba9..e6fca8666 100644
--- a/outlines/text/generate/sequence.py
+++ b/outlines/text/generate/sequence.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -79,7 +79,7 @@ def step(
         token_ids: torch.LongTensor,
         attention_mask: torch.LongTensor,
         samples: int = 1,
-    ) -> torch.LongTensor:
+    ) -> Tuple[torch.LongTensor, torch.FloatTensor]:
         """Generate one or several tokens that complete the input sequence.
 
         The sampling step consists in using a model to generate next-token
@@ -103,7 +103,8 @@ def step(
         -------
         A tuple with an array of shape `new_batch_shape + (num_tokens+1,)` that
         contains the completed sequences (input token ids and generated token
-        ids).
+        ids) and an array of shape `new_batch_shape + (vocab_size,)` that
+        contains the next token probabilities.
 
         `new_batch_shape` is computed by removing dimensions of size one in
         `(samples,) + batch_shape`.
@@ -134,7 +135,7 @@ def step(
         token_ids = torch.atleast_2d(token_ids.squeeze())
         probs = torch.atleast_2d(probs.squeeze())
 
-        return token_ids
+        return token_ids, probs
 
     def expand_attention_mask(
         self, attention_mask: torch.LongTensor
@@ -233,7 +234,7 @@ def __call__(
         num_prompt_tokens = token_ids.shape[-1]
 
         if samples > 1:
-            token_ids = self.step(
+            token_ids, _ = self.step(
                 rng, num_prompt_tokens, token_ids, attention_mask, samples
             )
             is_finished = self.is_finished(token_ids)
@@ -251,7 +252,7 @@ def __call__(
             if torch.all(is_finished) or num_generated_tokens == self.max_tokens:
                 break
 
-            updated_token_ids = self.step(
+            updated_token_ids, _ = self.step(
                 rng,
                 num_prompt_tokens,
                 token_ids[~is_finished],
diff --git a/tests/text/generate/test_sequence.py b/tests/text/generate/test_sequence.py
index 6d207992b..8e5f3dd15 100644
--- a/tests/text/generate/test_sequence.py
+++ b/tests/text/generate/test_sequence.py
@@ -116,11 +116,9 @@ def test_sequence_step():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2]])
-    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
+    token_ids, probs = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1]]))
-
-    logits_out = sequence._next_token_logits(2, input_ids, torch.ones((1, 2)))
-    assert logits_out.shape == (1, 4)
+    assert probs.shape == (1, 4)
 
 
 def test_sequence_step_batch():
@@ -133,11 +131,9 @@ def test_sequence_step_batch():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2], [3, 4]])
-    token_ids = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
+    token_ids, probs = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1], [3, 4, 1]]))
-
-    logits_out = sequence._next_token_logits(2, input_ids, torch.ones((2, 2)))
-    assert logits_out.shape == (2, 4)
+    assert probs.shape == (2, 4)
 
 
 def test_sequence_step_sample():
@@ -149,8 +145,9 @@ def test_sequence_step_sample():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2]])
-    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
+    token_ids, probs = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1], [1, 2, 1], [1, 2, 1]]))
+    assert probs.shape == (3, 4)
 
 
 def test_sequence_step_sample_batch():
@@ -162,7 +159,7 @@ def test_sequence_step_sample_batch():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2, 1], [3, 4, 1]])
-    token_ids = sequence.step(rng, 3, input_ids, torch.ones((2, 3)), samples=3)
+    token_ids, probs = sequence.step(rng, 3, input_ids, torch.ones((2, 3)), samples=3)
     assert torch.equal(
         token_ids,
         torch.tensor(
@@ -173,6 +170,7 @@ def test_sequence_step_sample_batch():
             ]
         ),
     )
+    assert probs.shape == (3, 2, 4)
 
 
 def test_sequence_step_loop():
@@ -185,22 +183,25 @@ def test_sequence_step_loop():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2]])
-    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
-    token_ids = sequence.step(rng, 2, token_ids, torch.ones((1, 3)))
+    token_ids, _ = sequence.step(rng, 2, input_ids, torch.ones((1, 2)))
+    token_ids, probs = sequence.step(rng, 2, token_ids, torch.ones((1, 3)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1, 1]]))
+    assert probs.shape == (1, 4)
 
     input_ids = torch.tensor([[1, 2], [3, 4]])
-    token_ids = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
-    token_ids = sequence.step(rng, 2, token_ids, torch.ones((2, 3)))
+    token_ids, _ = sequence.step(rng, 2, input_ids, torch.ones((2, 2)))
+    token_ids, probs = sequence.step(rng, 2, token_ids, torch.ones((2, 3)))
     assert torch.equal(token_ids, torch.tensor([[1, 2, 1, 1], [3, 4, 1, 1]]))
+    assert probs.shape == (2, 4)
 
     # The number of samples becomes the batch size at the next iteration.
    input_ids = torch.tensor([[1, 2]])
-    token_ids = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
-    token_ids = sequence.step(rng, 2, token_ids, torch.ones((3, 3)))
+    token_ids, _ = sequence.step(rng, 2, input_ids, torch.ones((1, 2)), samples=3)
+    token_ids, probs = sequence.step(rng, 2, token_ids, torch.ones((3, 3)))
     assert torch.equal(
         token_ids, torch.tensor([[1, 2, 1, 1], [1, 2, 1, 1], [1, 2, 1, 1]])
     )
+    assert probs.shape == (3, 4)
 
 
 def test_sequence_step_loop_general():
@@ -212,8 +213,8 @@ def test_sequence_step_loop_general():
     sequence = Sequence(model)
 
     input_ids = torch.tensor([[1, 2, 1], [3, 4, 1]])
-    token_ids = sequence.step(rng, 3, input_ids, torch.ones((1, 3)), samples=3)
-    result = sequence.step(rng, 3, token_ids, torch.ones((3, 4)))
+    token_ids, _ = sequence.step(rng, 3, input_ids, torch.ones((1, 3)), samples=3)
+    result, _ = sequence.step(rng, 3, token_ids, torch.ones((3, 4)))
     assert result.shape == (3, 2, 5)
     assert torch.equal(
         result,
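
Note (usage sketch, not part of the patches): the series nets out to adding `Sequence.next_token_logits()`, since patch 5 only restores `step()`'s original `(token_ids, probs)` return. A minimal sketch of calling the new method follows. It assumes a transformers-backed model built with `outlines.models.transformers` and the `continuation` generator from `outlines.text.generate` (a `Sequence` subclass, so it inherits the method); the "gpt2" checkpoint and the prompt are illustrative only.

    import torch

    import outlines.models as models
    from outlines.text import generate

    # Assumed entry points: any Sequence subclass exposes next_token_logits().
    model = models.transformers("gpt2")
    sequence = generate.continuation(model)

    # Unnormalised log probabilities over the vocabulary for the token that
    # would follow the prompt; a single string prompt squeezes the result to
    # shape (vocab_size,), a list of prompts to (batch_size, vocab_size).
    logits = sequence.next_token_logits("The capital of France is")
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Inspect the five most likely next tokens.
    top = torch.topk(probs, k=5)
    print(top.indices.tolist(), top.values.tolist())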