From 43c037e01c8b81af57911a461f2d291a48bb4d36 Mon Sep 17 00:00:00 2001
From: Jules Gagnon-Marchand
Date: Tue, 27 Aug 2024 18:43:40 -0700
Subject: [PATCH 1/3] Include the Prompt in the FSM Pattern

---
 outlines/generate/api.py | 268 ++++++++++++++++++++++-----------------
 1 file changed, 152 insertions(+), 116 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index ad01377c0..03c6eb782 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -125,124 +125,160 @@ def format_sequence(self, sequence: str) -> FormattedOutput:
         """
         return sequence
-
+
     def __call__(
-        self,
-        prompts: Union[str, List[str]],
-        max_tokens: Optional[int] = None,
-        stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional["torch.Generator"] = None,
-    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
-        """Generate the full text sequence.
-
-        Since `SequenceGenerator.stream` calls the tokenizer at every step this
-        method loops over the generator returned by `sequence_generator` itself
-        so the tokenizer is called only once after all token ids have been
-        generated.
-
-        Parameters
-        ----------
-        prompts
-            A string or list of strings that are passed to the model before
-            generating the first token.
-        max_tokens
-            An integer representing maximum number of tokens that will be generated
-            (per prompt)
-        stop_at
-            A string or list of strings at which the text generated will stop
-        rng
-            The random number generator. Defaults to a non-seeded `torch.Generator`
-            instance.
-
-        Returns
-        -------
-        The generation(s), potentially cast to another type.
-        """
-        import torch
-
-        if isinstance(prompts, str):
-            prompts = [prompts]
-
-        if isinstance(stop_at, str):
-            stop_at = [stop_at]
-
-        stop_sequences = stop_at
-        num_samples = self.num_samples
-
-        if rng is None:
-            rng = torch.Generator(device=self.device)
-            rng.seed()
-
-        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
-        prompt_token_ids = prompt_token_ids.to(self.device)
-        attention_masks = attention_masks.to(self.device)
-
-        # To draw multiple samples we repeat the prompt as many times
-        # as there are samples. We copy the FSMs and initialize the
-        # FSM states.
-        num_samples = self.num_samples
-        batch_size = len(prompts)
-
-        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
-        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
-        fsm_states = [0 for _ in range(batch_size * num_samples)]
-        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
-        weights = torch.zeros(
-            (batch_size * num_samples), dtype=torch.float, device=self.device
-        )
-
-        states = sequence_generator(
-            self.model,
-            self.sampler,
-            fsms,
-            prompt_token_ids,
-            weights,
-            attention_masks,
-            fsm_states,
-            rng=rng,
-        )
-
-        while True:
-            try:
-                last_state = next(states)
-                if max_tokens or stop_sequences:
-                    token_ids = last_state.token_ids
-                    generated_token_ids = self.get_generated_token_ids(
-                        prompt_token_ids, token_ids
-                    )
-                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
-                        break
-                    if stop_sequences and self.is_stop_sequence_found(
-                        self.tokenizer.decode(generated_token_ids), stop_sequences
-                    ):
-                        break
-            except StopIteration:
-                break
-
-        token_ids = last_state.token_ids
-        generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
+        self,
+        prompts: Optional[Union[str, List[str]]],
+        max_tokens: Optional[int] = None,
+        stop_at: Optional[Union[str, List[str]]] = None,
+        rng: Optional["torch.Generator"] = None,
+        include_prompt: bool = False,
+    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
+        """Generate the full text sequence.
+
+        Since `SequenceGenerator.stream` calls the tokenizer at every step this
+        method loops over the generator returned by `sequence_generator` itself
+        so the tokenizer is called only once after all token ids have been
+        generated.
+
+        Parameters
+        ----------
+        prompts
+            A string or list of strings that are passed to the model before
+            generating the first token.
+        max_tokens
+            An integer representing maximum number of tokens that will be generated
+            (per prompt)
+        stop_at
+            A string or list of strings at which the text generated will stop
+        rng
+            The random number generator. Defaults to a non-seeded `torch.Generator`
+            instance.
+        include_prompt
+            Whether the prompt itself should be included in the generation pattern.
+            The prompt is then also part of the returned output. Useful when the input
+            is, for example, the start of a JSON document and you want to generate the rest.
+
+        Returns
+        -------
+        The generation(s), potentially cast to another type.
+        """
+        import torch
+
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        if isinstance(stop_at, str):
+            stop_at = [stop_at]
+
+        stop_sequences = stop_at
+        num_samples = self.num_samples
+
+        if rng is None:
+            rng = torch.Generator(device=self.device)
+            rng.seed()
+
+        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
+        rich.print(f"Input: [green]{self.tokenizer.decode(prompt_token_ids)[0]}")
+
+        prompt_token_ids = prompt_token_ids.to(self.device)
+        attention_masks = attention_masks.to(self.device)
+
+        # To draw multiple samples we repeat the prompt as many times
+        # as there are samples. We copy the FSMs and initialize the
+        # FSM states.
+        num_samples = self.num_samples
+        batch_size = len(prompts)
+
+        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
+        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
+        fsm_states = [0 for _ in range(batch_size * num_samples)]
+        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
+        weights = torch.zeros(
+            (batch_size * num_samples), dtype=torch.float, device=self.device
+        )
+
+        if include_prompt:
+            final_fsm_states = []
+            for fsm, state, prompt_token_id_seq, attention_mask_seq in zip(
+                fsms, fsm_states, prompt_token_ids, attention_masks):
+                assert state == 0, f"Expected state to be 0, got {state}"
+                assert isinstance(state, int), f"Expected state to be an int, got {state}"
+                for token_id, attention_mask in zip(prompt_token_id_seq, attention_mask_seq):
+                    # print(f"{token_id = }, {attention_mask = }, {state = }")
+                    # print(f"\"{self.tokenizer.decode([token_id]) = }\"")
+                    state = fsm.get_next_state(state=state, token_id=token_id.item())
+                    if state < 0:
+                        raise ValueError(f"Invalid state {state}")
+                final_fsm_states.append(state)
+
+            start_len = len(fsm_states)
+            fsm_states = final_fsm_states
+            assert len(fsm_states) == start_len, f"{len(fsm_states) = }, {start_len = }"
+
+        states = sequence_generator(
+            self.model,
+            self.sampler,
+            fsms,
+            prompt_token_ids,
+            weights,
+            attention_masks,
+            fsm_states,
+            rng=rng,
+        )
+
+        while True:
+            try:
+                last_state = next(states)
+                if max_tokens or stop_sequences:
+                    token_ids = last_state.token_ids
+                    generated_token_ids = self.get_generated_token_ids(
+                        prompt_token_ids, token_ids
+                    )
+
+                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
+                        break
+
+                    if stop_sequences and self.is_stop_sequence_found(
+                        self.tokenizer.decode(generated_token_ids), stop_sequences
+                    ):
+                        break
+            except StopIteration:
+                break
+
+        token_ids = last_state.token_ids
+
+
+        # When the prompt is included in the pattern, keep the prompt token ids in the
+        # output so that the decoded text (e.g. a JSON document) remains parseable.
+        if include_prompt:
+            generated_token_ids = token_ids
+        else:
+            generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
+
+        generated = self.tokenizer.decode(generated_token_ids)
+        stripped = [
+            self.strip_stop_sequences(sequence, stop_sequences)
+            for sequence in generated
+        ]
+        formatted = [self.format_sequence(sequence) for sequence in stripped]
+        rich.print(f"Output: [green]{formatted[0]}")
+
+        # We reshape the output to (batch_size, sample_size)
+        output: List[List[FormattedOutput]] = list()
+        for i in range(0, batch_size * num_samples, num_samples):
+            output.append(formatted[i : i + num_samples])
+
+        # We remove leading dimensions for the output
+        if batch_size == 1 and num_samples == 1:
+            return output[0][0]
+        elif batch_size == 1:
+            return output[0]
+        elif num_samples == 1:
+            return [samples[0] for samples in output]
+        else:
+            return output

-        generated = self.tokenizer.decode(generated_token_ids)
-        stripped = [
-            self.strip_stop_sequences(sequence, stop_sequences)
-            for sequence in generated
-        ]
-        formatted = [self.format_sequence(sequence) for sequence in stripped]
-
-        # We reshape the output to (batch_size, sample_size)
-        output: List[List[FormattedOutput]] = list()
-        for i in range(0, batch_size * num_samples, num_samples):
-            output.append(formatted[i : i + num_samples])
-
-        # We remove leading dimensions for the output
-        if batch_size == 1 and num_samples == 1:
-            return output[0][0]
-        elif batch_size == 1:
-            return output[0]
-        elif num_samples == 1:
-            return [samples[0] for samples in output]
-        else:
-            return output

     def stream(
         self,

From 305891fe064eb51625230f6c0b7f775b90589457 Mon Sep 17 00:00:00 2001
From: Jules Gagnon-Marchand
Date: Tue, 27 Aug 2024 18:46:35 -0700
Subject: [PATCH 2/3] whitespace fix

---
 outlines/generate/api.py | 288 +++++++++++++++++++++--------------------
 1 file changed, 144 insertions(+), 144 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 03c6eb782..335b6d4a7 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -134,150 +134,150 @@ def __call__(
         rng: Optional["torch.Generator"] = None,
         include_prompt: bool = False,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
-        """Generate the full text sequence.
-
-        Since `SequenceGenerator.stream` calls the tokenizer at every step this
-        method loops over the generator returned by `sequence_generator` itself
-        so the tokenizer is called only once after all token ids have been
-        generated.
-
-        Parameters
-        ----------
-        prompts
-            A string or list of strings that are passed to the model before
-            generating the first token.
-        max_tokens
-            An integer representing maximum number of tokens that will be generated
-            (per prompt)
-        stop_at
-            A string or list of strings at which the text generated will stop
-        rng
-            The random number generator. Defaults to a non-seeded `torch.Generator`
-            instance.
-        include_prompt
-            Whether the prompt itself should be included in the generation pattern.
-            The prompt is then also part of the returned output. Useful when the input
-            is, for example, the start of a JSON document and you want to generate the rest.
-
-        Returns
-        -------
-        The generation(s), potentially cast to another type.
-        """
-        import torch
-
-        if isinstance(prompts, str):
-            prompts = [prompts]
-
-        if isinstance(stop_at, str):
-            stop_at = [stop_at]
-
-        stop_sequences = stop_at
-        num_samples = self.num_samples
-
-        if rng is None:
-            rng = torch.Generator(device=self.device)
-            rng.seed()
-
-        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
-        rich.print(f"Input: [green]{self.tokenizer.decode(prompt_token_ids)[0]}")
-
-        prompt_token_ids = prompt_token_ids.to(self.device)
-        attention_masks = attention_masks.to(self.device)
-
-        # To draw multiple samples we repeat the prompt as many times
-        # as there are samples. We copy the FSMs and initialize the
-        # FSM states.
-        num_samples = self.num_samples
-        batch_size = len(prompts)
-
-        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
-        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
-        fsm_states = [0 for _ in range(batch_size * num_samples)]
-        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
-        weights = torch.zeros(
-            (batch_size * num_samples), dtype=torch.float, device=self.device
-        )
-
-        if include_prompt:
-            final_fsm_states = []
-            for fsm, state, prompt_token_id_seq, attention_mask_seq in zip(
-                fsms, fsm_states, prompt_token_ids, attention_masks):
-                assert state == 0, f"Expected state to be 0, got {state}"
-                assert isinstance(state, int), f"Expected state to be an int, got {state}"
-                for token_id, attention_mask in zip(prompt_token_id_seq, attention_mask_seq):
-                    # print(f"{token_id = }, {attention_mask = }, {state = }")
-                    # print(f"\"{self.tokenizer.decode([token_id]) = }\"")
-                    state = fsm.get_next_state(state=state, token_id=token_id.item())
-                    if state < 0:
-                        raise ValueError(f"Invalid state {state}")
-                final_fsm_states.append(state)
-
-            start_len = len(fsm_states)
-            fsm_states = final_fsm_states
-            assert len(fsm_states) == start_len, f"{len(fsm_states) = }, {start_len = }"
-
-        states = sequence_generator(
-            self.model,
-            self.sampler,
-            fsms,
-            prompt_token_ids,
-            weights,
-            attention_masks,
-            fsm_states,
-            rng=rng,
-        )
-
-        while True:
-            try:
-                last_state = next(states)
-                if max_tokens or stop_sequences:
-                    token_ids = last_state.token_ids
-                    generated_token_ids = self.get_generated_token_ids(
-                        prompt_token_ids, token_ids
-                    )
-
-                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
-                        break
-
-                    if stop_sequences and self.is_stop_sequence_found(
-                        self.tokenizer.decode(generated_token_ids), stop_sequences
-                    ):
-                        break
-            except StopIteration:
-                break
-
-        token_ids = last_state.token_ids
-
-
-        # When the prompt is included in the pattern, keep the prompt token ids in the
-        # output so that the decoded text (e.g. a JSON document) remains parseable.
-        if include_prompt:
-            generated_token_ids = token_ids
-        else:
-            generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
-
-        generated = self.tokenizer.decode(generated_token_ids)
-        stripped = [
-            self.strip_stop_sequences(sequence, stop_sequences)
-            for sequence in generated
-        ]
-        formatted = [self.format_sequence(sequence) for sequence in stripped]
-        rich.print(f"Output: [green]{formatted[0]}")
-
-        # We reshape the output to (batch_size, sample_size)
-        output: List[List[FormattedOutput]] = list()
-        for i in range(0, batch_size * num_samples, num_samples):
-            output.append(formatted[i : i + num_samples])
-
-        # We remove leading dimensions for the output
-        if batch_size == 1 and num_samples == 1:
-            return output[0][0]
-        elif batch_size == 1:
-            return output[0]
-        elif num_samples == 1:
-            return [samples[0] for samples in output]
-        else:
-            return output
+        """Generate the full text sequence.
+
+        Since `SequenceGenerator.stream` calls the tokenizer at every step this
+        method loops over the generator returned by `sequence_generator` itself
+        so the tokenizer is called only once after all token ids have been
+        generated.
+
+        Parameters
+        ----------
+        prompts
+            A string or list of strings that are passed to the model before
+            generating the first token.
+        max_tokens
+            An integer representing maximum number of tokens that will be generated
+            (per prompt)
+        stop_at
+            A string or list of strings at which the text generated will stop
+        rng
+            The random number generator. Defaults to a non-seeded `torch.Generator`
+            instance.
+        include_prompt
+            Whether the prompt itself should be included in the generation pattern.
+            The prompt is then also part of the returned output. Useful when the input
+            is, for example, the start of a JSON document and you want to generate the rest.
+
+        Returns
+        -------
+        The generation(s), potentially cast to another type.
+        """
+        import torch
+
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        if isinstance(stop_at, str):
+            stop_at = [stop_at]
+
+        stop_sequences = stop_at
+        num_samples = self.num_samples
+
+        if rng is None:
+            rng = torch.Generator(device=self.device)
+            rng.seed()
+
+        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
+        rich.print(f"Input: [green]{self.tokenizer.decode(prompt_token_ids)[0]}")
+
+        prompt_token_ids = prompt_token_ids.to(self.device)
+        attention_masks = attention_masks.to(self.device)
+
+        # To draw multiple samples we repeat the prompt as many times
+        # as there are samples. We copy the FSMs and initialize the
+        # FSM states.
+        num_samples = self.num_samples
+        batch_size = len(prompts)
+
+        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
+        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
+        fsm_states = [0 for _ in range(batch_size * num_samples)]
+        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
+        weights = torch.zeros(
+            (batch_size * num_samples), dtype=torch.float, device=self.device
+        )
+
+        if include_prompt:
+            final_fsm_states = []
+            for fsm, state, prompt_token_id_seq, attention_mask_seq in zip(
+                fsms, fsm_states, prompt_token_ids, attention_masks):
+                assert state == 0, f"Expected state to be 0, got {state}"
+                assert isinstance(state, int), f"Expected state to be an int, got {state}"
+                for token_id, attention_mask in zip(prompt_token_id_seq, attention_mask_seq):
+                    # print(f"{token_id = }, {attention_mask = }, {state = }")
+                    # print(f"\"{self.tokenizer.decode([token_id]) = }\"")
+                    state = fsm.get_next_state(state=state, token_id=token_id.item())
+                    if state < 0:
+                        raise ValueError(f"Invalid state {state}")
+                final_fsm_states.append(state)
+
+            start_len = len(fsm_states)
+            fsm_states = final_fsm_states
+            assert len(fsm_states) == start_len, f"{len(fsm_states) = }, {start_len = }"
+
+        states = sequence_generator(
+            self.model,
+            self.sampler,
+            fsms,
+            prompt_token_ids,
+            weights,
+            attention_masks,
+            fsm_states,
+            rng=rng,
+        )
+
+        while True:
+            try:
+                last_state = next(states)
+                if max_tokens or stop_sequences:
+                    token_ids = last_state.token_ids
+                    generated_token_ids = self.get_generated_token_ids(
+                        prompt_token_ids, token_ids
+                    )
+
+                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
+                        break
+
+                    if stop_sequences and self.is_stop_sequence_found(
+                        self.tokenizer.decode(generated_token_ids), stop_sequences
+                    ):
+                        break
+            except StopIteration:
+                break
+
+        token_ids = last_state.token_ids
+
+
+        # When the prompt is included in the pattern, keep the prompt token ids in the
+        # output so that the decoded text (e.g. a JSON document) remains parseable.
+        if include_prompt:
+            generated_token_ids = token_ids
+        else:
+            generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
+
+        generated = self.tokenizer.decode(generated_token_ids)
+        stripped = [
+            self.strip_stop_sequences(sequence, stop_sequences)
+            for sequence in generated
+        ]
+        formatted = [self.format_sequence(sequence) for sequence in stripped]
+        rich.print(f"Output: [green]{formatted[0]}")
+
+        # We reshape the output to (batch_size, sample_size)
+        output: List[List[FormattedOutput]] = list()
+        for i in range(0, batch_size * num_samples, num_samples):
+            output.append(formatted[i : i + num_samples])
+
+        # We remove leading dimensions for the output
+        if batch_size == 1 and num_samples == 1:
+            return output[0][0]
+        elif batch_size == 1:
+            return output[0]
+        elif num_samples == 1:
+            return [samples[0] for samples in output]
+        else:
+            return output

     def stream(

From 80470c0d64bb52787f81d626c7690ac4867fec69 Mon Sep 17 00:00:00 2001
From: Jules Gagnon-Marchand
Date: Tue, 27 Aug 2024 18:47:45 -0700
Subject: [PATCH 3/3] removed comments and printing

---
 outlines/generate/api.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 335b6d4a7..52c51f137 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -179,7 +179,6 @@ def __call__(
             rng.seed()

         prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
-        rich.print(f"Input: [green]{self.tokenizer.decode(prompt_token_ids)[0]}")

         prompt_token_ids = prompt_token_ids.to(self.device)
         attention_masks = attention_masks.to(self.device)
@@ -205,8 +204,6 @@ def __call__(
                 assert state == 0, f"Expected state to be 0, got {state}"
                 assert isinstance(state, int), f"Expected state to be an int, got {state}"
                 for token_id, attention_mask in zip(prompt_token_id_seq, attention_mask_seq):
-                    # print(f"{token_id = }, {attention_mask = }, {state = }")
-                    # print(f"\"{self.tokenizer.decode([token_id]) = }\"")
                     state = fsm.get_next_state(state=state, token_id=token_id.item())
                     if state < 0:
                         raise ValueError(f"Invalid state {state}")
                 final_fsm_states.append(state)
@@ -262,7 +259,6 @@ def __call__(
             for sequence in generated
         ]
         formatted = [self.format_sequence(sequence) for sequence in stripped]
-        rich.print(f"Output: [green]{formatted[0]}")

         # We reshape the output to (batch_size, sample_size)
         output: List[List[FormattedOutput]] = list()
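
A minimal usage sketch, not part of the patches above: it only assumes how the new
`include_prompt` flag introduced by this series would be called. With
`include_prompt=True` the prompt is first walked through the FSM as the beginning of
the pattern, and the returned string contains the prompt plus the completion. The
model name and regex below are illustrative placeholders, using the outlines Python
API of that period.

    import outlines

    # Placeholder model and pattern, chosen only to illustrate the new flag.
    model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")
    pattern = r'\{ "name": "[A-Za-z ]+", "age": [0-9]+ \}'
    generator = outlines.generate.regex(model, pattern)

    # The prompt is the start of the JSON object described by the pattern.
    # With include_prompt=True it is fed through the FSM and kept in the output.
    result = generator('{ "name": "Jules", "age": ', include_prompt=True)
    print(result)  # e.g. '{ "name": "Jules", "age": 34 }'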