Add Support for Including the Prompt in the Pattern in SequenceGenerator #1122

Open · wants to merge 3 commits into base: main
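For context, a minimal usage sketch of the flag this PR adds (the model name, regex, and prompt below are illustrative assumptions; only `include_prompt` itself comes from the diff):

```python
import outlines

# Illustrative model; any transformers model usable with outlines works the same way.
model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")

# The regex describes the *whole* document, prompt included.
generator = outlines.generate.regex(model, r'\{"name": "[a-z]+", "age": [0-9]+\}')

# The prompt is the beginning of the JSON object the pattern describes.
prompt = '{"name": "alice", '

# With include_prompt=True the prompt is treated as the already-matched prefix of
# the pattern, and the returned string is prompt + completion, so it can be parsed
# as JSON directly.
result = generator(prompt, include_prompt=True)
print(result)  # e.g. '{"name": "alice", "age": 30}'
```

The changes below implement exactly this: the guide's FSM is advanced through the prompt's token ids before generation starts, and the prompt's tokens are kept in the decoded output.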
48 changes: 40 additions & 8 deletions outlines/generate/api.py
@@ -125,14 +125,15 @@ def format_sequence(self, sequence: str) -> FormattedOutput:

"""
return sequence

def __call__(
self,
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional["torch.Generator"] = None,
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
self,
prompts: Optional[Union[str, List[str]]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional["torch.Generator"] = None,
include_prompt: bool = False,
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
"""Generate the full text sequence.

Since `SequenceGenerator.stream` calls the tokenizer at every step this
@@ -153,6 +154,10 @@ def __call__(
rng
The random number generator. Defaults to a non-seeded `torch.Generator`
instance.
include_prompt
    Whether the prompt itself should be matched against the generation pattern,
    so that generation continues from the state the prompt leaves the pattern
    in. As a side effect, the prompt is included in the output. Useful if the
    input is, for example, the start of a JSON object and you want to generate
    the rest.

Returns
-------
@@ -174,6 +179,7 @@ def __call__(
rng.seed()

prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)

prompt_token_ids = prompt_token_ids.to(self.device)
attention_masks = attention_masks.to(self.device)

@@ -191,6 +197,22 @@ def __call__(
(batch_size * num_samples), dtype=torch.float, device=self.device
)

        if include_prompt:
            # Advance each FSM through its prompt's token ids so that generation
            # resumes from the state the prompt leaves the pattern in, rather than
            # from the pattern's initial state.
            final_fsm_states = []
            for fsm, state, prompt_token_id_seq, attention_mask_seq in mit.zip_equal(
                fsms, fsm_states, prompt_token_ids, attention_masks
            ):
                assert isinstance(state, int), f"Expected state to be an int, got {state}"
                assert state == 0, f"Expected state to be 0, got {state}"
                for token_id, attention_mask in zip(
                    prompt_token_id_seq, attention_mask_seq
                ):
                    state = fsm.get_next_state(state=state, token_id=token_id.item())
                    if state < 0:
                        raise ValueError(f"Invalid state {state}")
                final_fsm_states.append(state)
lapp0 (Contributor) commented on Sep 17, 2024:

Would it be better to implement this as a Guide? I've started work on #531 and I'd like to see Guides become more composable.

Also, interegular has built-in functionality that makes this easy.

>>> import interegular
>>> fsm = interegular.parse_pattern("abcd").to_fsm()
>>> fsm.accepts("abcd")
True
>>> partial_fsm = fsm.derive("ab")
>>> partial_fsm.accepts("abcd")
False
>>> partial_fsm.accepts("cd")
True

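A minimal sketch of how that `derive`-based approach could map onto this PR's use case (the pattern and prompt are illustrative assumptions; `parse_pattern`, `derive`, and `accepts` are the interegular calls shown above):

```python
import interegular

# The full pattern covers prompt + completion.
fsm = interegular.parse_pattern('"name": "[a-z]+"').to_fsm()

# Deriving by the prompt text leaves an FSM that accepts only the remainder;
# this is the character-level analogue of the state this PR reaches by stepping
# the guide's FSM through the prompt's token ids.
remainder = fsm.derive('"name": "')
print(remainder.accepts('alice"'))  # True
```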

            # Replace the initial states with the states reached after consuming
            # the prompts, checking that the batch size is unchanged.
            start_len = len(fsm_states)
            fsm_states = final_fsm_states
            assert len(fsm_states) == start_len, f"{len(fsm_states) = }, {start_len = }"

states = sequence_generator(
self.model,
self.sampler,
@@ -210,8 +232,10 @@ def __call__(
generated_token_ids = self.get_generated_token_ids(
prompt_token_ids, token_ids
)

if max_tokens and len(generated_token_ids[0]) >= max_tokens:
break

if stop_sequences and self.is_stop_sequence_found(
self.tokenizer.decode(generated_token_ids), stop_sequences
):
@@ -220,7 +244,14 @@ def __call__(
break

token_ids = last_state.token_ids
generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)


        # When the prompt is included in the pattern, the "generated" token ids
        # are the full token ids: keeping the prompt token ids in the output is
        # what lets the returned text be parsed (e.g. as JSON) afterwards.
        if include_prompt:
            generated_token_ids = token_ids
        else:
            generated_token_ids = self.get_generated_token_ids(
                prompt_token_ids, token_ids
            )
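        # Illustration with hypothetical values (not from this diff): for prompt
        # '{"name": "' and completion 'alice"}', only the concatenation
        # '{"name": "alice"}' parses as JSON, hence the prompt token ids are kept
        # when include_prompt is set.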

generated = self.tokenizer.decode(generated_token_ids)
stripped = [
@@ -244,6 +275,7 @@ def __call__(
else:
return output


def stream(
self,
prompts: Union[str, List[str]],