diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index ad01377c0..52c51f137 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -125,14 +125,15 @@ def format_sequence(self, sequence: str) -> FormattedOutput:
 
         """
         return sequence
-
+
     def __call__(
-        self,
-        prompts: Union[str, List[str]],
-        max_tokens: Optional[int] = None,
-        stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional["torch.Generator"] = None,
-    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
+        self,
+        prompts: Optional[Union[str, List[str]]],
+        max_tokens: Optional[int] = None,
+        stop_at: Optional[Union[str, List[str]]] = None,
+        rng: Optional["torch.Generator"] = None,
+        include_prompt: bool = False,
+    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
         """Generate the full text sequence.
 
         Since `SequenceGenerator.stream` calls the tokenizer at every step this
@@ -153,6 +154,10 @@ def __call__(
         rng
             The random number generator. Defaults to a non-seeded `torch.Generator`
             instance.
+        include_prompt
+            Whether the prompt itself should be included in the generation pattern,
+            which also means the prompt is kept in the output. Useful if the prompt is
+            the start of a JSON document, for example, and you want to generate the rest.
 
         Returns
         -------
@@ -174,6 +179,7 @@ def __call__(
             rng.seed()
 
         prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
+
         prompt_token_ids = prompt_token_ids.to(self.device)
         attention_masks = attention_masks.to(self.device)
 
@@ -191,6 +197,22 @@ def __call__(
             (batch_size * num_samples), dtype=torch.float, device=self.device
         )
 
+        if include_prompt:
+            final_fsm_states = []
+            for fsm, state, prompt_token_id_seq, attention_mask_seq in mit.zip_equal(
+                    fsms, fsm_states, prompt_token_ids, attention_masks):
+                assert state == 0, f"Expected state to be 0, got {state}"
+                assert isinstance(state, int), f"Expected state to be an int, got {state}"
+                for token_id, attention_mask in zip(prompt_token_id_seq, attention_mask_seq):
+                    state = fsm.get_next_state(state=state, token_id=token_id.item())
+                    if state < 0:
+                        raise ValueError(f"Invalid state {state}")
+                final_fsm_states.append(state)
+
+            start_len = len(fsm_states)
+            fsm_states = final_fsm_states
+            assert len(fsm_states) == start_len, f"{len(fsm_states) = }, {start_len = }"
+
         states = sequence_generator(
             self.model,
             self.sampler,
@@ -210,8 +232,10 @@ def __call__(
                 generated_token_ids = self.get_generated_token_ids(
                     prompt_token_ids, token_ids
                 )
+
                 if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                     break
+
                 if stop_sequences and self.is_stop_sequence_found(
                     self.tokenizer.decode(generated_token_ids), stop_sequences
                 ):
@@ -220,7 +244,14 @@ def __call__(
                 break
 
         token_ids = last_state.token_ids
-        generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
+
+
+        # When the prompt is part of the generation pattern, keep the prompt token ids
+        # in the output: stripping them would leave text that no longer parses as JSON.
+        if include_prompt:
+            generated_token_ids = token_ids
+        else:
+            generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
 
         generated = self.tokenizer.decode(generated_token_ids)
         stripped = [
@@ -244,6 +275,7 @@ def __call__(
         else:
             return output
 
+
     def stream(
         self,
         prompts: Union[str, List[str]],
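
Usage sketch (not part of the patch): a minimal example of the new include_prompt flag, assuming the diff above is applied and the outlines 0.0.x Python API (outlines.models.transformers, outlines.generate.regex); the model name and regex below are placeholders. The regex constrains the whole document and the prompt is its beginning, so with include_prompt=True the FSM is first advanced through the prompt tokens and the prompt stays in the returned string.

# Hypothetical usage sketch -- assumes the patch above is applied; the model
# name and the regex are placeholders.
import json

import outlines

model = outlines.models.transformers("gpt2")

# The pattern describes the *whole* document, prompt included.
pattern = r'\{"name": "[A-Za-z ]+", "age": [0-9]{1,3}\}'
generator = outlines.generate.regex(model, pattern)

# The prompt is the beginning of the JSON object. With include_prompt=True the
# FSM is advanced through the prompt tokens before sampling starts, and the
# prompt token ids are kept in the decoded output.
prompt = '{"name": "'
result = generator(prompt, include_prompt=True)

print(json.loads(result))  # parses, because the prompt prefix is retained

With include_prompt=False (the default) the same call would return only the completion, which on its own would not parse as JSON.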