
Fix incompleteness of regex and cfg guided generation #544

Merged: 4 commits into dottxt-ai:main on Jan 24, 2024

Conversation

benlipkin (Contributor)

I'm opening a PR with a few miscellaneous FSM enhancements. Each is minor, but all focus on improving the completeness and/or performance of the FSM classes.


In this first commit, I modify how final_state is determined. Previously, for RegexFSM, is_final_state was implemented as return state in self.final_states. This logic is incorrect: generation should not terminate merely because an accepting state has been reached; it should terminate only once an EOS has been generated (otherwise a* effectively degrades to a?). Here are the counts of unique outputs over an example generation of 1000 samples from the regex a*:

['', 'a']
[557, 443]

Instead, the FSM interface now defines final_state = FSMState(-1); in next_state, if token_id == self.eos_token_id we return self.final_state; and is_final_state can now be implemented simply as return state == self.final_state. Here are 1000 new samples, whose lengths are now correctly geometrically distributed:

['', 'a', 'aa', 'aaa', 'aaaa', 'aaaaa', 'aaaaaa', 'aaaaaaa', 'aaaaaaaa', 'aaaaaaaaa']
[524, 192, 101,  85,  45,  26,  11,   4,   6,   6]
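
(For reference, a tally like the one above can be gathered with a few lines; generate_sample() below is a hypothetical stand-in for one constrained generation from the model, not an outlines API.)

from collections import Counter

# generate_sample() is hypothetical: one constrained generation under
# RegexFSM("a*", tokenizer); replace with your actual generation call.
counts = Counter(generate_sample() for _ in range(1000))
outputs = sorted(counts, key=len)
print(outputs)                        # e.g. ['', 'a', 'aa', ...]
print([counts[o] for o in outputs])   # e.g. [524, 192, 101, ...]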

While making this modification, I also cleaned up the code in a few more places.

Now that stop_at is handled by SequenceGenerator (thanks to #451), StopAtTokenFSM is only ever used to stop at EOS, and I've made this explicit. Next, with both of the other FSMs using this new implementation, I brought CFGFSM on board as well (which also reduces some messiness and brings it closer to the rest of the interface) and moved the implementation up to the abstract protocol level. This also let us eliminate the hard-coded -1 states floating around the code by defining the sentinel once at the interface.
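
A minimal sketch of the resulting pattern (illustrative only; the names follow the description above, not the exact diff):

from typing import NewType

FSMState = NewType("FSMState", int)

class FSM:
    # Single termination sentinel, defined once at the interface level.
    final_state = FSMState(-1)

    def is_final_state(self, state: FSMState) -> bool:
        # Terminate only after EOS has actually been sampled, not merely
        # upon reaching an accepting state of the underlying automaton.
        return state == self.final_state

class RegexFSM(FSM):
    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        if token_id == self.eos_token_id:
            return self.final_state
        return FSMState(self.states_to_token_maps[state][token_id])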

@lapp0 (Contributor) commented Jan 17, 2024

Nice!

One thing to be mindful of as you develop this PR: make sure eos_token_id is in allowed_token_ids() for states that can legally be considered final.

# Assumed imports (TransformerTokenizer's path is my guess at the repo layout):
from outlines.fsm.fsm import RegexFSM
from outlines.models.transformers import TransformerTokenizer

tokenizer = TransformerTokenizer("gpt2")
rfsm = RegexFSM("a*", tokenizer)
print("state_to_token_maps:", rfsm.states_to_token_maps)
print("finals:", rfsm.final_states)
print("allowed(0):", rfsm.allowed_token_ids(0))
print("allowed(1):", rfsm.allowed_token_ids(1))
print("eos_token_id:", tokenizer.eos_token_id)

e.g. on main this prints

state_to_token_maps: {0: {46071: 1, 64: 1, 24794: 1, 50256: 0, 7252: 1}, 1: {46071: 1, 64: 1, 50256: 1, 24794: 1, 7252: 1}}
finals: frozenset({0, 1, -1})
allowed(0): [46071, 64, 24794, 50256, 7252]
allowed(1): [46071, 64, 50256, 24794, 7252]
eos_token_id: 50256

@benlipkin (Contributor, Author)

Thanks @lapp0! Confirmed eos_token_id is still in allowed_token_ids whenever we are in an accepting state. The difference now is that it is allowed to be sampled, rather than forcing termination. Here are the replicated outputs from the snippet above on the current branch.

In [1]:
tokenizer = TransformerTokenizer("gpt2")
rfsm = RegexFSM("a*", tokenizer)
print("state_to_token_maps:", rfsm.states_to_token_maps)
print("allowed(0):", rfsm.allowed_token_ids(0))
print("allowed(1):", rfsm.allowed_token_ids(1))
print("eos_token_id:", tokenizer.eos_token_id)

Out [1]:
state_to_token_maps: {0: {46071: 0, 7252: 0, 64: 0, 50256: 0, 24794: 0}}
allowed(0): [46071, 7252, 64, 50256, 24794]
allowed(1): [50256]
eos_token_id: 50256

In [2]:
inv_vocab = {v: k for k, v in tokenizer.vocabulary.items()}
for tok in sorted(rfsm.allowed_token_ids(0)):
    print(tok, inv_vocab[tok])

Out [2]:
64 a
7252 aa
24794 aaaa
46071 aaa
50256 <|endoftext|>

@benlipkin (Contributor, Author)

Okay, here's the second enhancement. I had raised in #391 that the current CFG implementation was also incomplete: in a setting where it could either extend the current terminal or start the next terminal, it would always start the next terminal. This precluded important strings, e.g., multi-token variable names. This has now been fixed by proposing both the tokens that would extend the current terminal and the tokens that would terminate it and start the next terminal; the LM can then sample to select which path to take. Both FSMs (each tracking progress through its own terminal) are stored, and once we see which token the LM generates, only the matching one is transitioned and any stale ones are removed (a simplified sketch of the proposal logic appears after the examples below).

For example, for the following grammar:

grammar = r"""
    start: s
    s: /a+/
    | "(" s ")"
"""

previously, only the following strings were possible:

a
(a)
((a))
(((a)))
...

It was never possible for the innermost string to extend, since the terminal was matched after a single character and only next-terminal tokens were then allowed.

Now, strings such as those in this next example set are supported:

(aaaaa)
(((aaa)))
((aa))
aaaa
...
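
Roughly, the proposal logic now looks like this (a simplified sketch, not the exact code; _next_terminal_token_ids is a hypothetical helper standing in for asking the parser which terminals may legally follow):

def allowed_token_ids(self, state):
    proposal = []
    # Tokens that would extend the terminal currently being matched.
    proposal += self.regex_fsm.allowed_token_ids(state)
    # Tokens that would close the current terminal and start the next one;
    # _next_terminal_token_ids is a hypothetical stand-in for the parser
    # query over legal next terminals.
    proposal += self._next_terminal_token_ids()
    return proposal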

I've updated the documentation and added a new test case covering the example that used to xfail, and I've also added new examples to examples/cfg.py, including a Dyck-like grammar similar to the one above as well as JSON.

@benlipkin changed the title from "[WIP] Misc FSM Enhancements" to "Fix incompleteness of regex and cfg guided generation" on Jan 17, 2024
@benlipkin (Contributor, Author) commented Jan 17, 2024

Okay, this is it for now; ready for review.

I was also thinking of memoizing the RegexFSM objects constructed by CFGFSM on the class, since these are often shared across a batch and across multiple rounds of generation (sketched below). But with the new caching of the states-to-token maps during RegexFSM construction, the actual time saved is only ~10% in my initial profiling, and it increases memory usage and code complexity. So I'm leaving it off for now, but it may be worth revisiting in the future.
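
(For concreteness, the memoization I considered would look roughly like this; it is a sketch, not code from the PR.)

class CFGFSM:
    # Shared across instances so a batch reuses identical terminal FSMs.
    # Keyed on the regex string alone for brevity; a real version would
    # also need to key on the tokenizer.
    _regex_fsm_cache: dict = {}

    def _get_regex_fsm(self, regex_string: str) -> "RegexFSM":
        # Build each terminal's RegexFSM at most once per regex string.
        if regex_string not in self._regex_fsm_cache:
            self._regex_fsm_cache[regex_string] = RegexFSM(regex_string, self.tokenizer)
        return self._regex_fsm_cache[regex_string]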

Another optimization that should definitely happen eventually (and will have a greater performance impact) is smarter state tracking / caching for the incremental parsing. Currently (https://github.com/benlipkin/outlines/blob/fsm-enhancements/outlines/fsm/fsm.py#L270-L272) we feed the whole prefix to the parser again each time we go back to it to propose a regex for entering a new accepting state. This duplicates lexer and parser work over the entire already-committed prefix (up through the last closed token), and it can be ameliorated with some clever caching (a rough illustration is below). I know some other performance improvements are also baked into the PartialLark implementation at https://github.com/outlines-dev/outlines/blob/main/outlines/fsm/parsing.py, so it's probably worth transitioning over to that while making these adjustments, but I haven't worked through that code yet.
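
(A rough illustration of the caching idea using lark's interactive-parser API; this is an assumption on my part, not code from this PR.)

from lark import Lark

parser = Lark(grammar, parser="lalr")  # parse_interactive requires LALR
_state_cache: dict = {}

def parser_state_for(prefix: str):
    # Cache the interactive-parser state reached after each committed
    # prefix so repeated proposals don't re-lex/re-parse it from scratch.
    if prefix not in _state_cache:
        interactive = parser.parse_interactive(prefix)
        interactive.exhaust_lexer()
        _state_cache[prefix] = interactive
    return _state_cache[prefix].copy()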

In the meantime, let me know if you have any comments on the two enhancements above.

Thanks!

@lapp0 (Contributor) left a review

Great work, LGTM.

@lapp0 (Contributor) commented Jan 18, 2024

IMHO, the other enhancements should be a separate PR.

I did some work on parser caching in https://github.com/vllm-project/vllm/pull/2105/files#diff-2d38048e543f22b75ca22a9c5f45292a7f56bb4ad1a3e6c5f606a1a583b64707R270

Effectively, it caches the push-down automata based on the parser's stack.
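
For a flavor of that, a stack-keyed cache might look like the following (a hypothetical sketch; see the linked diff for the real implementation):

_accepts_cache: dict = {}

def cached_accepts(interactive_parser):
    # Parses whose LALR state stacks match accept the same next terminals,
    # so the stack is a sound cache key.
    key = tuple(interactive_parser.parser_state.state_stack)
    if key not in _accepts_cache:
        _accepts_cache[key] = interactive_parser.accepts()
    return _accepts_cache[key]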

Happy to collaborate on introducing parser performance improvements if you're interested. I'm working on some CFG benchmarks right now, which should help with #549.

@benlipkin (Contributor, Author)

Thanks @lapp0!

Sounds good re keeping this PR to completeness-focused enhancements and tackling performance-focused optimizations in a different PR.

Yes, I would be interested in collaborating on this (thanks for sharing your vllm commit). It's awesome to see more benchmarking/profiling happening in #549.

@lapp0 (Contributor) commented Jan 21, 2024

I'm writing a few CFGFSM benchmark tests, and none of my valid example JSON and Python files could be generated. After merging this branch into my branch, they could be generated.

@rlouf could you review when you get a chance?

@rlouf (Member) commented Jan 24, 2024

Sorry for the time it took me to review the PR; I needed to make sure I understood the changes. This looks good to me, thank you for opening a PR!

@rlouf added the structured generation and enhancement labels on Jan 24, 2024
@rlouf merged commit 46dc706 into dottxt-ai:main on Jan 24, 2024 (5 checks passed)
@lapp0 mentioned this pull request on Jan 26, 2024