diff --git a/benchmarks/bench_cfg_guide.py b/benchmarks/bench_cfg_guide.py new file mode 100644 index 000000000..8f6de914a --- /dev/null +++ b/benchmarks/bench_cfg_guide.py @@ -0,0 +1,61 @@ +import random + +from transformers import AutoTokenizer + +import outlines.grammars +from outlines.caching import cache_disabled +from outlines.fsm.guide import CFGGuide +from outlines.models.transformers import TransformerTokenizer + +from .common import ensure_numba_compiled + +random.seed(42) + + +def get_tiny_tokenizer(): + """1000 tokens in vocabulary""" + return TransformerTokenizer( + AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + ) + + +benched_grammars = { + "json": outlines.grammars.json, + "arithmetic": outlines.grammars.arithmetic, +} + + +class CFGGuideBenchmark: + params = benched_grammars.keys() + + def setup(self, grammar_name): + self.tokenizer = get_tiny_tokenizer() + ensure_numba_compiled( + self.tokenizer + ) # numba not currently used, but will be in the future + self.prebuilt_cfg_guide = CFGGuide( + benched_grammars[grammar_name], self.tokenizer + ) + + @staticmethod + def _run_random_cfg(guide): + state = guide.initial_state + token_ids = list(guide.tokenizer.vocabulary.values()) + for i in range(40): + # simulate ordering of logits top prob to lowest prob + random.shuffle(token_ids) + # simulate sampling and state update + next_token_id = next(guide.iter_valid_token_ids(state, token_ids)) + state = guide.get_next_state(state, next_token_id) + + @cache_disabled() + def time_cfg_guide_setup(self, grammar_name): + CFGGuide(benched_grammars[grammar_name], self.tokenizer) + + @cache_disabled() + def time_cfg_guide_run(self, grammar): + self._run_random_cfg(self.prebuilt_cfg_guide) + + @cache_disabled() + def peakmem_cfg_guide_run(self, grammar): + self._run_random_cfg(self.prebuilt_cfg_guide) diff --git a/docs/reference/creating_grammars.md b/docs/reference/creating_grammars.md new file mode 100644 index 000000000..78e41282a --- /dev/null +++ b/docs/reference/creating_grammars.md @@ -0,0 +1,99 @@ +# Overview + +Outlines allows the use of [Lark](https://github.com/lark-parser/lark) grammars to guide generation. These grammars are used to construct parsers that filter out incompatible tokens during the generation process The result is a generation that adheres to the grammar's production rules. + +# Primer on Creating Grammars + +To create grammars for Outlines, a solid understanding of Lark grammars is necessary. Here's how you can get started: + +- Read Lark's grammars documentations [here](https://lark-parser.readthedocs.io/en/latest/grammar.html). +- Review Outlines' existing grammars [here](/outlines/grammars). + + +# Compatibility With Outlines + +It's important to note that not all Lark grammars work with Outlines. Changes may be necessary to ensure compatability. + +### LALR(1) Parser + +Outlines utilizes Larks LALR(1) parser, meaning the grammar must be unambiguous at least up to the next token (one token lookahead). Read Lark's official LALR(1) parser documentation [here](https://lark-parser.readthedocs.io/en/stable/parsers.html#lalr-1). + +If your grammar is ambiguous, you will recieve the following error at runtime: + +``` +GrammarError: Reduce/Reduce collision in Terminal('B') between the following rules: +``` + +### Regex Terminal Restrictions + +Outlines converts terminals to finite state machines using the [Interegular](https://github.com/MegaIng/interegular/) library. Not all regular expressions work with Interegular, mitigation is described in the subsections which follow. + + +#### Avoid Lookarounds + +Examples of removing lookaround while maintaining the same functionality + +##### Example: Escaped String + +From Outlines' modified `ESCAPED_STRING` in [common.lark](/outlines/grammars/common.lark). + +Before: +``` +_STRING_INNER: /.*?/ +_STRING_ESC_INNER: _STRING_INNER /(?) +print(result) +``` + +# Converting +There are a few tools available for converting from other grammars to lark. These tools serve as a starting point. However, you will typically need to make additional adjustments to ensure full compatibility and proper functioning within Outlines. + +Tools: +- Larks built in "Nearley-to-Lark" converter https://lark-parser.readthedocs.io/en/latest/tools.html +- Convert ANTLR4 to Lark (Note, most antlr4 grammars are not LALR(1) compatible, so will require additional tweaking) https://github.com/kaby76/Domemtech.Trash/blob/main/src/trconvert/readme.md +- Extract EBNF from Yacc files https://www.bottlecaps.de/rr/ui + +Reference Grammars: +- Github Lark Grammars https://github.com/search?q=path%3A*.lark&type=code +- Github Nearley Grammars https://github.com/search?q=path%3A*.ne+%22-%3E%22&type=code +- Antlr4 grammars https://github.com/antlr/grammars-v4/ +- Grammar zoo https://slebok.github.io/zoo/index.html#html diff --git a/mkdocs.yml b/mkdocs.yml index e5f3417ac..ba3b088a9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -126,6 +126,7 @@ nav: - JSON (function calling): reference/json.md - JSON mode: reference/json_mode.md - Grammar: reference/cfg.md + - Creating Grammars: reference/creating_grammars.md - Custom FSM operations: reference/custom_fsm_ops.md - Utilities: - Serve with vLLM: reference/serve/vllm.md diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index 4a7fce8c9..bfcf55c03 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -1,7 +1,7 @@ import warnings from typing import TYPE_CHECKING, Iterable, NewType, Optional -from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide +from outlines.fsm.guide import RegexGuide, StopAtEOSGuide if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -45,25 +45,3 @@ def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: def next_state(self, state: FSMState, token_id: int) -> FSMState: return FSMState(self.get_next_state(state, token_id)) - - -class CFGFSM(CFGGuide): - """FSM to generate text that is in the language of a context-free grammar.""" - - def __init__(self, cfg_string: str, tokenizer): - warnings.warn( - UserWarning( - "The `CFGFSM` interface is deprecated and will be removed on 2024-06-01. Please use `CFGGuide` instead." - ) - ) - super().__init__(cfg_string, tokenizer) - - def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: - return self.get_next_instruction(state).tokens - - def next_state(self, state: FSMState, token_id: int) -> FSMState: - return FSMState(self.get_next_state(state, token_id)) - - def copy(self) -> "CFGFSM": - """Create a copy of the FSM.""" - return CFGFSM(self.cfg_string, self.tokenizer) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 2e4415148..aa073d107 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -1,8 +1,12 @@ +import collections +import copy from dataclasses import dataclass from typing import ( TYPE_CHECKING, + Any, Callable, Dict, + Generator, List, Optional, Protocol, @@ -13,10 +17,12 @@ import interegular import torch -from lark import Lark +from lark.indenter import DedentError +from lark.lexer import UnexpectedCharacters, UnexpectedToken from outlines import grammars from outlines.caching import cache +from outlines.fsm.parsing import PartialLark, PartialParserState from outlines.fsm.regex import ( create_fsm_index_tokenizer, make_byte_level_fsm, @@ -69,13 +75,15 @@ class Guide(Protocol): """ - def get_next_instruction(self, state: int) -> Instruction: + initial_state: Any + + def get_next_instruction(self, state: Any) -> Instruction: ... - def get_next_state(self, state: int, token_id: int) -> int: + def get_next_state(self, state: Any, token_id: int) -> Any: ... - def is_final_state(self, state: int) -> bool: + def is_final_state(self, state: Any) -> bool: ... def copy(self) -> "Guide": @@ -86,7 +94,8 @@ class StopAtEOSGuide(Guide): """Guide to generate tokens until the EOS token has been generated.""" final_state = 1 - start_state = 0 + start_state = 0 # TODO: remove start_state, use only initial_state + initial_state = 0 def __init__(self, tokenizer: "Tokenizer"): """Initialize the generation guide. @@ -107,7 +116,7 @@ def get_next_state(self, state: int, token_id: int) -> int: if token_id == self.eos_token_id or state == self.final_state: return self.final_state - return self.start_state + return self.initial_state def is_final_state(self, state: int): return state == self.final_state @@ -300,178 +309,175 @@ def copy(self): return self +CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"]) + + class CFGGuide(Guide): - """Guide to generate text that is in the language of a context-free grammar.""" + """Guide to generate text that is in the language of a context-free Lark grammar.""" def __init__(self, cfg_string: str, tokenizer): + """ + Construct the PartialLark parser and set the empty initial_state (PartialParserState) + """ self.cfg_string = cfg_string self.tokenizer = tokenizer - - self.parser = Lark( + self.eos_token_id = self.tokenizer.eos_token_id + self.parser = PartialLark( cfg_string, parser="lalr", - lexer="contextual", - propagate_positions=False, - maybe_placeholders=False, - regex=True, import_paths=[grammars.GRAMMAR_PATH], ) - self.terminal_regexps = dict() - for terminal in self.parser.terminals: - if terminal.pattern is not None: - self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp() - self.terminal_regexps["$END"] = tokenizer.eos_token - - self.generation = "" - self.reset_state = False - self.allow_eos = False - self.regex_fsm: RegexGuide - - self.check_last = False - self.proposal_last: List[int] = [] - self.regex_fsm_last: RegexGuide - - self.start_state = 0 - self.final_state = -1 - - def get_next_instruction(self, state: int) -> Instruction: - """Generate an instruction for the next step. - - Upon initialization, the CFG incremental parser is used to determine the - first regex and construct the first FSM to generate the first terminal. + self.initial_state = CFGState( + parser_state=self.parser.parse(""), prev_token=None + ) - This FSM is used for proposals until either: + def get_next_instruction(self, state: CFGState) -> Instruction: + """Return the next instruction for guided generation. - - The FSM is exhausted, and its only remaining option is the EOS token, - in which case we feed the generated terminal to the - CFG incremental parser and allow it to propose the next regex - corresponding to the next set of valid terminals. - - The current FSM can be exhausted, but the EOS token is not the only - remaining option. In this case we allow proposal of current terminal - extensions, store the current FSM and its state, then also use the CFG - parser to propose a new regex corresponding to terminating the current - terminal and starting the next one. The model can then sample from - either of these sets to determine whether to extend the current - terminal or terminate it and start the next one. + Current lazy approach: + - For each token in the vocabulary + - create a copy of the parsers state + - add the tokens to the parsers input text + - if valid, add token to returned tokens - The CFG incremental parser is allowed to propose the EOS token from any accepting state, - and once it is generated, the FSM will continue to always generate the EOS token. + Further refinements are necessary for performant text processing. Parameters ---------- state - The current state of the FSM. + The guides current PartialParserState, or None if complete Returns ------- - A list that contains the tokens to mask. + A `Generate` instance that contains the model and the allowed token ids. """ - if self.is_final_state(state): - return Write([self.tokenizer.eos_token_id]) - proposal: List[int] = [] - if self.generation != "": - if self.check_last: - proposer = self.regex_fsm_last - else: - proposer = self.regex_fsm + if state.parser_state is None: + return Write(torch.tensor([self.eos_token_id])) - instruction = proposer.get_next_instruction(state) + valid_tokens = list( + self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values()) + ) + if len(valid_tokens) == 1: + return Write(torch.tensor(valid_tokens)) + return Generate(torch.tensor(valid_tokens)) - assert instruction.tokens is not None + def iter_valid_token_ids( + self, state: CFGState, candidate_token_ids: list + ) -> Generator[int, None, None]: + """ + Iterate over the given token_ids and yield those that are valid for the current parser state. - if isinstance(instruction, Write): - proposal += instruction.tokens + Parameters + ---------- + parser_state + The current state of the parser, or None if complete. + token_ids + The list of token ids to check for validity. + + Yields + ------ + int + Valid token ids. + """ + if state.parser_state is None: + yield self.eos_token_id + return + + for token_id in candidate_token_ids: + if token_id == self.eos_token_id: + if self.can_terminate_state(state): + yield token_id else: - proposal += instruction.tokens - - if self.tokenizer.eos_token_id not in proposal: - return Generate(proposal) - - self.check_last = False - proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] - if len(proposal) > 0: - self.check_last = True - self.proposal_last = proposal.copy() - self.regex_fsm_last = proposer - - interactive = self.parser.parse_interactive(self.generation) - interactive.exhaust_lexer() - - options = {self.terminal_regexps[x] for x in interactive.accepts()} - # add %ignore terminals - options |= {self.terminal_regexps[x] for x in self.parser.lexer_conf.ignore} - - if self.terminal_regexps["$END"] in options: - options.remove(self.terminal_regexps["$END"]) - if len(options) == 0: - return Write([self.tokenizer.eos_token_id]) - self.allow_eos = True - options.add("") - assert len(options) > 1 - - regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" - self.regex_fsm = RegexGuide(regex_string, self.tokenizer) - self.reset_state = True - - instruction = self.regex_fsm.get_next_instruction(self.start_state) - - assert instruction.tokens is not None - - if isinstance(instruction, Write): - proposal += instruction.tokens - else: - proposal += instruction.tokens - - if self.allow_eos: - self.allow_eos = False - else: - proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] - assert len(proposal) > 0 - - return Generate(proposal) - - def get_next_state(self, state: int, token_id: int) -> int: - """Update the state of the guide. - - Transitions the underlying regex FSM to its next state. - If at max tokens or EOS token, transition permanently to the final state. - Update stored partial generations for subsequent incremental parsing. + try: + self._get_parser_state_token_applied(state, int(token_id)) + yield token_id + except ( + ValueError, + EOFError, + UnexpectedToken, + UnexpectedCharacters, + DedentError, + ): + pass + + def get_next_state(self, state: CFGState, token_id: int) -> CFGState: + """ + Update the state of the guide. + Decode the token_id, and calculate the new parser_state with the token applied. Parameters ---------- state - The current state of the FSM. + The guides current PartialParserState, or None if complete token_id The id of the token that was just generated. Returns ------- - The new state of the FSM. - """ + The guides new PartialParserState - # We need to return the final state when in the final state because we - # then generate EOS tokens instead of stopping the generation. - if token_id == self.tokenizer.eos_token_id or state == self.final_state: - return self.final_state - - self.generation += self.tokenizer.decode([token_id])[0] + """ + if state.parser_state is None or token_id == self.eos_token_id: + parser_state = None + else: + parser_state = self._get_parser_state_token_applied(state, int(token_id)) + return CFGState(parser_state=parser_state, prev_token=token_id) - if self.check_last: - if token_id in self.proposal_last: - return self.regex_fsm_last.get_next_state(state, token_id) - self.check_last = False + def _get_parser_state_token_applied( + self, state: CFGState, token_id: int + ) -> PartialParserState: + """ + Don't mutate `parser_state`, copy to protect - if self.reset_state: - self.reset_state = False - state = self.start_state + Get the token string + - if first token in generation: tokenizer.decode (no leading whitespace) + - else: normalized (with possibly leading whitespace) - return self.regex_fsm.get_next_state(state, token_id) + Don't allow empty ("") tokens, raise ValueError + """ + parser_state = copy.copy(state.parser_state) # prevent side effects - def is_final_state(self, state: int) -> bool: - return state == self.final_state + # normalize + if state.prev_token is None: + new_token_str = self.tokenizer.decode([token_id])[0] + else: + prev_token_str = self.tokenizer.decode([[state.prev_token]])[0] + combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[ + 0 + ] + new_token_str = combined_token_str[len(prev_token_str) :] + + if new_token_str == "": + raise ValueError("empty next token") + + # update parser with new token + parser_state.lexer.state.text += new_token_str + self.parser.parse_from_state(parser_state, is_end=False) + + return parser_state + + def is_final_state(self, state: CFGState) -> bool: + # TODO: remove this method, use can_terminate_state and must_terminate_state + # here and in RegexGuide per https://github.com/outlines-dev/outlines/issues/885 + return self.can_terminate_state(state) + + def can_terminate_state(self, state: CFGState) -> bool: + """Generation is allowed to terminate""" + if state.parser_state is not None: + try: + copy.copy(state.parser_state).feed_eof() + except UnexpectedToken: + return False + return True + + def must_terminate_state(self, state: CFGState) -> bool: + """Generation must terminate, no legal continuations""" + return state.parser_state is None or set(state.parser_state.accepts()).issubset( + {"$END"} + ) def copy(self) -> "CFGGuide": - """Create a copy of the FSM.""" + """Create a copy of the Guide.""" return CFGGuide(self.cfg_string, self.tokenizer) diff --git a/outlines/fsm/parsing.py b/outlines/fsm/parsing.py index e4fa7b764..f780fb46e 100644 --- a/outlines/fsm/parsing.py +++ b/outlines/fsm/parsing.py @@ -447,6 +447,49 @@ def feed_token_no_stack(self, token, is_end=False): if is_end and state_stack[-1] == end_state: return + def feed_eof(self): + last_token = self.lexer.state.last_token + + if last_token is None: + eof_token = self.lexer._Token("$END", "", 0, 1, 1) + else: + eof_token = Token.new_borrow_pos("$END", "", last_token) + + new_token_is_legal = ( + last_token is None + or last_token.type != "partial" + or any(ti.is_final for ti in last_token.value.terminals_and_info) + ) + if new_token_is_legal: + self.feed_token(eof_token, is_end=True) + else: + raise UnexpectedToken(eof_token, [], state=self, interactive_parser=None) + + def choices(self): + return self.parse_conf.parse_table.states[self.position] + + def accepts(self): + """ + Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95 + Returns the set of possible tokens that will advance the parser into a new valid state. + """ + accepts = set() + conf_no_callbacks = copy(self.parse_conf) + # We don't want to call callbacks here since those might have arbitrary side effects + # and are unnecessarily slow. + conf_no_callbacks.callbacks = {} + for t in self.choices(): + if t.isupper(): # is terminal? + new_state = copy(self) + new_state.parse_conf = conf_no_callbacks + try: + new_state.feed_token(new_state.lexer._Token(t, "")) + except UnexpectedToken: + pass + else: + accepts.add(t) + return accepts + def __copy__(self): return type(self)( self.parse_conf, @@ -483,12 +526,7 @@ def parse_from_state(self, state, last_token=None, is_end=False): state.feed_token(token) if is_end and (not token or token.type != "partial"): - end_token = ( - Token.new_borrow_pos("$END", "", token) - if token - else Token("$END", "", 0, 1, 1) - ) - state.feed_token(end_token, True) + state.feed_eof() return state except UnexpectedInput as e: @@ -614,6 +652,8 @@ def __init__(self, conf: "LexerConf", states, always_accept=()): lexer_conf.terminals = [ terminals_by_name[n] for n in accepts if n in terminals_by_name ] + if not lexer_conf.terminals: + continue lexer = PartialBasicLexer(lexer_conf) lexer_by_symbols[key] = lexer @@ -626,9 +666,22 @@ def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: try: while True: lexer = self.lexers[parser_state.position] - yield lexer.next_token(lexer_state, parser_state) + next_tok = lexer.next_token(lexer_state, parser_state) + yield next_tok except EOFError: pass + except KeyError: + if len(lexer_state.text) > lexer_state.line_ctr.char_pos: + raise UnexpectedCharacters( + lexer_state.text, + lexer_state.line_ctr.char_pos, + lexer_state.line_ctr.line, + lexer_state.line_ctr.column, + allowed=False, + token_history=lexer_state.last_token and [lexer_state.last_token], + state=parser_state, + terminals_by_name=self.root_lexer.terminals, + ) class PartialBasicLexer(BasicLexer): diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index 0df833067..4f372f209 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -1,7 +1,10 @@ from functools import singledispatch -from outlines.generate.api import SequenceGeneratorAdapter -from outlines.models import OpenAI +from outlines.generate.api import ( + SequenceGeneratorAdapter, + VisionSequenceGeneratorAdapter, +) +from outlines.models import ExLlamaV2Model, LlamaCpp, OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial @@ -14,8 +17,7 @@ def cfg( Arguments --------- model: - An instance of `Transformer` that represents a model from the - `transformers` library. + An `outlines.model` instance. sampler: The sampling algorithm to use to generate token ids from the logits distribution. @@ -25,13 +27,32 @@ def cfg( A `SequenceGeneratorAdapter` instance that generates text. """ + from outlines.processors import CFGLogitsProcessor + + logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@cfg.register(TransformersVision) +def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()): + from outlines.processors import CFGLogitsProcessor + + logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer) + return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) + + +@cfg.register(ExLlamaV2Model) +def cfg_exllamav2(model, cfg_str: str, sampler: Sampler = multinomial()): raise NotImplementedError( - f"The CFG Logits processor is not available for {type(model)}. " - + "Please subscribe to https://github.com/outlines-dev/outlines/issues/684" - + " for updates on the fix." + "Not yet available, track progress in https://github.com/outlines-dev/outlines/pull/1010" ) +@cfg.register(LlamaCpp) +def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()): + raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer") + + @cfg.register(OpenAI) def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()): raise NotImplementedError( diff --git a/outlines/generate/fsm.py b/outlines/generate/fsm.py index 03fe512b9..a9338836a 100644 --- a/outlines/generate/fsm.py +++ b/outlines/generate/fsm.py @@ -16,19 +16,19 @@ def fsm( model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() ) -> SequenceGeneratorAdapter: - from outlines.processors import FSMLogitsProcessor + from outlines.processors import GuideLogitsProcessor - fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) - logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm) + guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) + logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide) return SequenceGeneratorAdapter(model, logits_processor, sampler) @fsm.register(TransformersVision) def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()): - from outlines.processors import FSMLogitsProcessor + from outlines.processors import GuideLogitsProcessor - fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) - logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm) + guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) + logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide) return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) diff --git a/outlines/grammars/common.lark b/outlines/grammars/common.lark index 801c27e97..ee5e00c50 100644 --- a/outlines/grammars/common.lark +++ b/outlines/grammars/common.lark @@ -43,11 +43,14 @@ SIGNED_FLOAT: ["+"|"-"] FLOAT NUMBER: FLOAT | INT SIGNED_NUMBER: ["+"|"-"] NUMBER -// -// TODO: Working escaped_string -// UNESCAPED_STRING: /\"[^"]*\"/ +// based on `outlines/fsm/json_schema.py` +_NON_CONTROL_CHAR: /([^"\\\x00-\x1F\x7F-\x9F])/ +_ESCAPED_CHAR: /\\/ (_NON_CONTROL_CHAR | /\\/ | /"/) +ESCAPED_STRING_INNER: _NON_CONTROL_CHAR | _ESCAPED_CHAR +ESCAPED_STRING: /"/ ESCAPED_STRING_INNER* /"/ + // diff --git a/outlines/grammars/json.lark b/outlines/grammars/json.lark index 72af448ce..7429fa558 100644 --- a/outlines/grammars/json.lark +++ b/outlines/grammars/json.lark @@ -2,7 +2,7 @@ ?value: object | array -| UNESCAPED_STRING +| ESCAPED_STRING | SIGNED_NUMBER -> number | "true" -> true | "false" -> false @@ -10,9 +10,9 @@ array : "[" [value ("," value)*] "]" object : "{" [pair ("," pair)*] "}" -pair : UNESCAPED_STRING ":" value +pair : ESCAPED_STRING ":" value -%import common.UNESCAPED_STRING +%import common.ESCAPED_STRING %import common.SIGNED_NUMBER %import common.WS diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py index 22c10d905..f0f0f829b 100644 --- a/outlines/processors/__init__.py +++ b/outlines/processors/__init__.py @@ -1,6 +1,6 @@ from .structured import ( CFGLogitsProcessor, - FSMLogitsProcessor, + GuideLogitsProcessor, JSONLogitsProcessor, OutlinesLogitsProcessor, RegexLogitsProcessor, diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index 0966a90db..6e68ad70b 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -24,7 +24,7 @@ limitations under the License. """ import math -from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import torch from pydantic import BaseModel @@ -39,36 +39,41 @@ from outlines.models.tokenizer import Tokenizer -class FSMLogitsProcessor(OutlinesLogitsProcessor): - """Bias generation using a finite state machine. +class GuideLogitsProcessor(OutlinesLogitsProcessor): + """Bias generation using a finite Attributes ---------- tokenizer The tokenizer used to convert tokens to ids. - fsm - The finite state machine which is used to bias the logits. + guide + The `outlines.fsm.Guide` which is used to bias the logits. """ - def __init__(self, tokenizer: "Tokenizer", fsm: Guide): - """A FSM-based logits processor. + tokenizer: "Tokenizer" + guide: Guide + _guide_states: Dict[int, Any] + _seq_start_idx: Optional[int] + + def __init__(self, tokenizer: "Tokenizer", guide: Guide): + """A Guide-based logits processor. Parameters ---------- tokenizer The tokenizer used to convert tokens to ids. - fsm - The finite state machine which is used to bias the logits. + guide + The `outlines.fsm.Guide. which is used to bias the logits. """ self.tokenizer = tokenizer - self._fsm_states: Dict[int, int] = {hash(tuple([])): 0} - self.fsm: Guide = fsm - self._seq_start_idx: Optional[int] = None + self.guide = guide + self._guide_states = {hash(tuple([])): self.guide.initial_state} + self._seq_start_idx = None def process_logits( self, input_ids: List[List[int]], logits: torch.Tensor ) -> torch.Tensor: - """Use the FSM to bias the logits before sampling the next token. + """Use the Guide to bias the logits before sampling the next token. Parameters ---------- @@ -91,38 +96,38 @@ def process_logits( gen_ids = seq_ids[self._seq_start_idx :] curr_state_key = hash(tuple(gen_ids)) - if curr_state_key not in self._fsm_states: - prev_state = self._fsm_states[hash(tuple(gen_ids[:-1]))] - curr_state = self.fsm.get_next_state(prev_state, gen_ids[-1]) - self._fsm_states[curr_state_key] = curr_state + if curr_state_key not in self._guide_states: + prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))] + curr_state = self.guide.get_next_state(prev_state, gen_ids[-1]) + self._guide_states[curr_state_key] = curr_state - sequence_states.append(self._fsm_states[curr_state_key]) + sequence_states.append(self._guide_states[curr_state_key]) mask = torch.full_like(logits, -math.inf) - for i, fsm_state in enumerate(sequence_states): - allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens + for i, guide_state in enumerate(sequence_states): + allowed_tokens = self.guide.get_next_instruction(guide_state).tokens mask[i, allowed_tokens] = logits[i, allowed_tokens] return mask - def copy(self) -> "FSMLogitsProcessor": + def copy(self) -> "GuideLogitsProcessor": """Return a copy of the logits processor.""" - return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy()) + return GuideLogitsProcessor(tokenizer=self.tokenizer, guide=self.guide.copy()) -class RegexLogitsProcessor(FSMLogitsProcessor): +class RegexLogitsProcessor(GuideLogitsProcessor): """Bias generation based on a regular expression. Attributes ---------- tokenizer The tokenizer used to convert tokens to ids. - fsm - The finite state machine which is used to bias the logits. + guide + The `outlines.fsm.RegexGuide. which is used to bias the logits. """ def __init__(self, regex_string: str, tokenizer: "Tokenizer"): - """Compile the FSM that drives the regex-guided generation. + """Compile the RegexGuide that drives the regex-guided generation. Parameters ---------- @@ -131,8 +136,8 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"): tokenizer An Outlines tokenizer """ - fsm = RegexGuide(regex_string, tokenizer) - super().__init__(tokenizer=tokenizer, fsm=fsm) + guide = RegexGuide(regex_string, tokenizer) + super().__init__(tokenizer=tokenizer, guide=guide) class JSONLogitsProcessor(RegexLogitsProcessor): @@ -142,8 +147,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ---------- tokenizer The tokenizer used to convert tokens to ids. - fsm - The finite state machine which is used to bias the logits. + guide + The `outlines.fsm.RegexGuide. which is used to bias the logits. """ def __init__( @@ -152,7 +157,7 @@ def __init__( tokenizer: "Tokenizer", whitespace_pattern: Optional[str] = None, ): - """Compile the FSM that drives the JSON-guided generation. + """Compile the Guide that drives the JSON-guided generation. Parameters ---------- @@ -170,19 +175,21 @@ def __init__( super().__init__(regex_string=regex_string, tokenizer=tokenizer) -class CFGLogitsProcessor(FSMLogitsProcessor): +class CFGLogitsProcessor(GuideLogitsProcessor): """Bias generation based on a context-free grammar. Attributes ---------- tokenizer The tokenizer used to convert tokens to ids. - fsm - The finite state machine which is used to bias the logits. + guide + The `outlines.fsm.CFGGuide. which is used to bias the logits. """ + guide: CFGGuide + def __init__(self, cfg_str: str, tokenizer: "Tokenizer"): - """Compile the FSM that drives the CFG-guided generation. + """Compile the CFGGuide that drives the CFG-guided generation. Parameters ---------- @@ -191,5 +198,36 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"): tokenizer The tokenizer used to convert tokens to ids. """ - cfg_automata = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer) - super().__init__(tokenizer=tokenizer, fsm=cfg_automata) + cfg_guide = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer) + super().__init__(tokenizer=tokenizer, guide=cfg_guide) + + def process_logits( + self, input_ids: List[List[int]], logits: torch.Tensor + ) -> torch.Tensor: + """Same behavior as GuideLogitsProcessor, but uses rejection sampling""" + if self._seq_start_idx is None: + self._seq_start_idx = len(input_ids[0]) + + sequence_states: List = [] # vector of states corresponding to `input_ids` + + for seq_ids in input_ids: + gen_ids = seq_ids[self._seq_start_idx :] + curr_state_key = hash(tuple(gen_ids)) + + if curr_state_key not in self._guide_states: + prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))] + curr_state = self.guide.get_next_state(prev_state, gen_ids[-1]) + self._guide_states[curr_state_key] = curr_state + + sequence_states.append(self._guide_states[curr_state_key]) + + mask = torch.full_like(logits, -math.inf) + for i, guide_state in enumerate(sequence_states): + first_legal_token = next( + self.guide.iter_valid_token_ids( + guide_state, torch.argsort(logits[i], descending=True) + ) + ) + mask[i, [first_legal_token]] = logits[i, [first_legal_token]] + + return mask diff --git a/tests/cfg_samples/arithmetic/lots_of_ops.arithmetic.test b/tests/cfg_samples/arithmetic/lots_of_ops.arithmetic.test new file mode 100644 index 000000000..1489aebc7 --- /dev/null +++ b/tests/cfg_samples/arithmetic/lots_of_ops.arithmetic.test @@ -0,0 +1 @@ +5+1+1+1+1+1+1+1+1+1+1+1+1+1+1+1+1+1 diff --git a/tests/cfg_samples/arithmetic/simple_math.arithmetic.test b/tests/cfg_samples/arithmetic/simple_math.arithmetic.test new file mode 100644 index 000000000..882f05c8d --- /dev/null +++ b/tests/cfg_samples/arithmetic/simple_math.arithmetic.test @@ -0,0 +1 @@ +(1 * 2) - (0.1 * 2 * 9.42) diff --git a/tests/cfg_samples/json/outlines.generate.samplers.mypy.json.test b/tests/cfg_samples/json/outlines.generate.samplers.mypy.json.test new file mode 100644 index 000000000..1a328a9b6 --- /dev/null +++ b/tests/cfg_samples/json/outlines.generate.samplers.mypy.json.test @@ -0,0 +1,372 @@ +{ + ".class": "MypyFile", + "_fullname": "outlines.generate.samplers", + "future_import_flags": [], + "is_partial_stub_package": false, + "is_stub": false, + "names": { + ".class": "SymbolTable", + "Protocol": { + ".class": "SymbolTableNode", + "cross_ref": "typing.Protocol", + "kind": "Gdef" + }, + "Sampler": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "TypeInfo", + "_promote": [], + "abstract_attributes": [ + [ + "__call__", + 2 + ] + ], + "alt_promote": null, + "bases": [ + "builtins.object" + ], + "dataclass_transform_spec": null, + "declared_metaclass": null, + "defn": { + ".class": "ClassDef", + "fullname": "outlines.generate.samplers.Sampler", + "name": "Sampler", + "type_vars": [] + }, + "deletable_attributes": [], + "flags": [ + "is_abstract", + "is_protocol" + ], + "fullname": "outlines.generate.samplers.Sampler", + "has_param_spec_type": false, + "metaclass_type": "abc.ABCMeta", + "metadata": {}, + "module_name": "outlines.generate.samplers", + "mro": [ + "outlines.generate.samplers.Sampler", + "builtins.object" + ], + "names": { + ".class": "SymbolTable", + "__call__": { + ".class": "SymbolTableNode", + "kind": "Mdef", + "node": { + ".class": "FuncDef", + "abstract_status": 2, + "arg_kinds": [ + 0, + 0, + 0, + 0 + ], + "arg_names": [ + "self", + "logits", + "samples", + "rng" + ], + "dataclass_transform_spec": null, + "flags": [ + "is_trivial_body" + ], + "fullname": "outlines.generate.samplers.Sampler.__call__", + "name": "__call__", + "type": { + ".class": "CallableType", + "arg_kinds": [ + 0, + 0, + 0, + 0 + ], + "arg_names": [ + "self", + "logits", + "samples", + "rng" + ], + "arg_types": [ + "outlines.generate.samplers.Sampler", + { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + }, + "builtins.int", + { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + } + ], + "bound_args": [], + "def_extras": { + "first_arg": "self" + }, + "fallback": "builtins.function", + "from_concatenate": false, + "implicit": false, + "is_ellipsis_args": false, + "name": "__call__ of Sampler", + "ret_type": { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + }, + "type_guard": null, + "unpack_kwargs": false, + "variables": [] + } + } + } + }, + "self_type": null, + "slots": null, + "tuple_type": null, + "type_vars": [], + "typeddict_type": null + } + }, + "__annotations__": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "Var", + "flags": [ + "is_ready" + ], + "fullname": "outlines.generate.samplers.__annotations__", + "name": "__annotations__", + "type": { + ".class": "Instance", + "args": [ + "builtins.str", + { + ".class": "AnyType", + "missing_import_name": null, + "source_any": null, + "type_of_any": 6 + } + ], + "type_ref": "builtins.dict" + } + } + }, + "__doc__": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "Var", + "flags": [ + "is_ready" + ], + "fullname": "outlines.generate.samplers.__doc__", + "name": "__doc__", + "type": "builtins.str" + } + }, + "__file__": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "Var", + "flags": [ + "is_ready" + ], + "fullname": "outlines.generate.samplers.__file__", + "name": "__file__", + "type": "builtins.str" + } + }, + "__name__": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "Var", + "flags": [ + "is_ready" + ], + "fullname": "outlines.generate.samplers.__name__", + "name": "__name__", + "type": "builtins.str" + } + }, + "__package__": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "Var", + "flags": [ + "is_ready" + ], + "fullname": "outlines.generate.samplers.__package__", + "name": "__package__", + "type": "builtins.str" + } + }, + "greedy": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "FuncDef", + "abstract_status": 0, + "arg_kinds": [ + 0, + 0, + 2 + ], + "arg_names": [ + "logits", + "samples", + "_" + ], + "dataclass_transform_spec": null, + "flags": [], + "fullname": "outlines.generate.samplers.greedy", + "name": "greedy", + "type": { + ".class": "CallableType", + "arg_kinds": [ + 0, + 0, + 2 + ], + "arg_names": [ + "logits", + "samples", + "_" + ], + "arg_types": [ + { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + }, + "builtins.int", + { + ".class": "AnyType", + "missing_import_name": null, + "source_any": null, + "type_of_any": 1 + } + ], + "bound_args": [], + "def_extras": { + "first_arg": null + }, + "fallback": "builtins.function", + "from_concatenate": false, + "implicit": false, + "is_ellipsis_args": false, + "name": "greedy", + "ret_type": { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + }, + "type_guard": null, + "unpack_kwargs": false, + "variables": [] + } + } + }, + "multinomial": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "FuncDef", + "abstract_status": 0, + "arg_kinds": [ + 0, + 0, + 0 + ], + "arg_names": [ + "logits", + "samples", + "rng" + ], + "dataclass_transform_spec": null, + "flags": [], + "fullname": "outlines.generate.samplers.multinomial", + "name": "multinomial", + "type": { + ".class": "CallableType", + "arg_kinds": [ + 0, + 0, + 0 + ], + "arg_names": [ + "logits", + "samples", + "rng" + ], + "arg_types": [ + { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + }, + "builtins.int", + { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + } + ], + "bound_args": [], + "def_extras": { + "first_arg": null + }, + "fallback": "builtins.function", + "from_concatenate": false, + "implicit": false, + "is_ellipsis_args": false, + "name": "multinomial", + "ret_type": { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + }, + "type_guard": null, + "unpack_kwargs": false, + "variables": [] + } + } + }, + "torch": { + ".class": "SymbolTableNode", + "kind": "Gdef", + "node": { + ".class": "Var", + "flags": [ + "is_suppressed_import", + "is_ready", + "is_inferred" + ], + "fullname": "outlines.generate.samplers.torch", + "name": "torch", + "type": { + ".class": "AnyType", + "missing_import_name": "outlines.generate.samplers.torch", + "source_any": null, + "type_of_any": 3 + } + } + } + }, + "path": "/home/andrew/p/outlines/outlines/generate/samplers.py" +} diff --git a/tests/cfg_samples/json/simple_fruit.json.test b/tests/cfg_samples/json/simple_fruit.json.test new file mode 100644 index 000000000..bffa952c7 --- /dev/null +++ b/tests/cfg_samples/json/simple_fruit.json.test @@ -0,0 +1,20 @@ +[ + { + "ID": "1", + "Name": "Andrew \"The Escaper\" Lapp", + "Age": "30", + "FavFruit": "Banana" + }, + { + "ID": "2", + "Name": "Mohammad", + "Age": "40", + "FavFruit": "\"Any Fruit As Long as It's In Quotes!\"" + }, + { + "ID": "3", + "Name": "Alice", + "Age": "61", + "FavFruit": "Peaches, but only \n newline separated peaches" + } +] diff --git a/tests/cfg_samples/json/simple_fruit_no_indent.json.test b/tests/cfg_samples/json/simple_fruit_no_indent.json.test new file mode 100644 index 000000000..9b7d319da --- /dev/null +++ b/tests/cfg_samples/json/simple_fruit_no_indent.json.test @@ -0,0 +1 @@ +[{"ID": "1", "Name": "Andrew", "Age": "30", "FavFruit": "Banana"}, {"ID": "2", "Name": "Mohammad", "Age": "40", "FavFruit": "Apple"}, {"ID": "3", "Name": "Alice", "Age": "61", "FavFruit": "Peach"}] diff --git a/tests/fsm/test_cfg_guide.py b/tests/fsm/test_cfg_guide.py new file mode 100644 index 000000000..d92afa625 --- /dev/null +++ b/tests/fsm/test_cfg_guide.py @@ -0,0 +1,457 @@ +from collections import namedtuple +from pathlib import Path + +import pytest +from transformers import AutoTokenizer + +from outlines import grammars, models +from outlines.fsm.guide import CFGGuide + + +@pytest.fixture +def cleanup_lark_import(): + import importlib + + import lark.lark + + yield + # Clean up lark.lark.LarkOptions._defaults + importlib.reload(lark.lark) + + +TestInputs = namedtuple( + "TestInputs", + [ + "grammar", # the lark grammar to validate against + "vocabulary", # the token strings which can be concatenated for a generation + "generated", # the tokens which have been generated so far + "legal_next_tokens", # the subset of the vocabulary which can legally be next in `generated` + ], +) + + +cfg_test_inputs = { + "Next Token Doesn't Complete Terminal": TestInputs( + grammar=r'?start: "a" "bc"', + vocabulary=["a", "ab", "b", "c"], + generated=["a"], + legal_next_tokens=["b"], + ), + "Ambiguous Terminal Completion": TestInputs( + grammar=r'?start: "ab" | "abc"', + vocabulary=["a", "ab", "abc", "abcd", "b", "c"], + generated=["a"], + legal_next_tokens=["b"], + ), + "Token is Substring of Another Token": TestInputs( + grammar=r'?start: "abc" | "abcd"', + vocabulary=["a", "b", "bc", "bcd", "bcde"], + generated=["a"], + legal_next_tokens=["b", "bc", "bcd"], + ), + "Multiple Valid Continuations": TestInputs( + grammar=r'?start: ("a" "b") | ("a" "c")', + vocabulary=["a", "b", "bc", "c"], + generated=["a"], + legal_next_tokens=["b", "c"], + ), + "Prefix Matches Multiple Terminals": TestInputs( + grammar=r'?start: "abcd" | "abef"', + vocabulary=["a", "b", "be", "bcd", "bef", "bed"], + generated=["a"], + legal_next_tokens=["b", "be", "bcd", "bef"], + ), + "Token Matches Multiple Paths in Grammar": TestInputs( + grammar=r'?start: ("a" "b" "c") | ("a" "b" "d")', + vocabulary=["a", "b", "c", "d"], + generated=["a", "b"], + legal_next_tokens=["c", "d"], + ), + "Incomplete Terminal at End of Prefix": TestInputs( + grammar=r'?start: "abc"', + vocabulary=["a", "ab", "c", "abc", "abcd"], + generated=["ab"], + legal_next_tokens=["c"], + ), + "Complex Grammar Rules": TestInputs( + grammar=r'?start: "a" "b" ["c"]', + vocabulary=["a", "b", "c"], + generated=["a", "b"], + legal_next_tokens=["c", None], # Allowing the document to end after "a" "b" + ), + "Empty Prefix String": TestInputs( + grammar=r'?start: "a" | "b"', + vocabulary=["a", "b", "c", "d"], + generated=[], + legal_next_tokens=["a", "b"], + ), + "Ambiguous Pattern Completion": TestInputs( + grammar=r'?start: /a+/ "b" /c?d/', + vocabulary=["a", "aa", "b", "cd", "d"], + generated=["a", "a", "b"], + legal_next_tokens=["cd", "d"], + ), + "Optional Patterns with Overlapping Tokens": TestInputs( + grammar=r'?start: "a" "b"? "c"', + vocabulary=["a", "b", "bc", "c"], + generated=["a"], + legal_next_tokens=["b", "bc", "c"], + ), + "Greedy vs. Non-Greedy Matching": TestInputs( + grammar=r'?start: /a+?/ "b" /c/', + vocabulary=["a", "aa", "aaa", "b", "c"], + generated=["a", "a", "b"], + legal_next_tokens=["c"], + ), + "Nested Optional Elements": TestInputs( + grammar=r'?start: "a" ["b" ["c"]]', + vocabulary=["a", "b", "bc", "c"], + generated=["a"], + legal_next_tokens=[ + "b", + "bc", + None, + ], # Allowing the document to end after "a" "b" + ), + "Recursive Patterns": TestInputs( + grammar=r'?start: /a(bc)*/ "d"', + vocabulary=["a", "bc", "bcbcbc", "d"], + generated=["a", "bc", "d"], + legal_next_tokens=[None], # Allowing the document to end after "a" "bc" "d" + ), + "Overlapping Character Classes": TestInputs( + grammar=r'?start: /[ab]+/ "d"', + vocabulary=["a", "b", "c", "aa", "bb", "cc", "d"], + generated=["a", "b"], + legal_next_tokens=["d", "a", "b", "aa", "bb"], + ), + "Conditional Patterns": TestInputs( + grammar=r'?start: "a" /b/ "c" (/d/)?', + vocabulary=["a", "b", "c", "d"], + generated=["a", "b", "c"], + legal_next_tokens=["d", None], # Allowing the document to end after "a" "b" "c" + ), + "Unicode and Special Characters": TestInputs( + grammar=r'?start: /[a-zA-Z]/ "é" /[0-9]+/', + vocabulary=["a", "b", "é", "1", "2", "12"], + generated=["a", "é"], + legal_next_tokens=["1", "2", "12"], + ), + "Unicode and Special Characters Are Choices": TestInputs( + grammar=r'?start: /[a-zA-Z]/ "é" /[0-9]+/', + vocabulary=["a", "b", "é", "é9", "2", "12"], + generated=["a"], + legal_next_tokens=["é", "é9"], + ), + "Whitespace and Ignored Characters": TestInputs( + grammar=r'?start: "a" / *\s*b/ "c"', + vocabulary=["a", " b", " c", "c"], + generated=["a", " b"], + legal_next_tokens=["c"], + ), + "Token Overlaps Multiple Terminals": TestInputs( + grammar=r'?start: "a" "b" "c" "ab"', + vocabulary=["a", "b", "bc", "cab", "abc"], + generated=["a"], + legal_next_tokens=["b", "bc"], + ), + "Interleaved Sequences": TestInputs( + grammar=r'?start: ("a" "b") | ("a" "c")', + vocabulary=["a", "b", "c", "ab", "ac"], + generated=["a"], + legal_next_tokens=["b", "c"], + ), + "Repeated and Nested Patterns": TestInputs( + grammar=r'?start: "a" ("b" "c")* "d"', + vocabulary=["a", "b", "c", "bc", "bcc", "cbc", "bcbc", "cbccccd", "d", "bcbcd"], + generated=["a", "b", "c"], + legal_next_tokens=["b", "bc", "bcbc", "d", "bcbcd"], + ), + "Ambiguous Ending Patterns": TestInputs( + grammar=r'?start: "a" (/b/)? (/c/)*', + vocabulary=["a", "b", "c"], + generated=["a", "b"], + legal_next_tokens=["c", None], # Allowing the document to end after "a" "b" + ), + "Whitespace Handling in Patterns": TestInputs( + grammar=r'?start: "a" / *\s*b/ /c /', + vocabulary=["a", " b", "c"], + generated=["a", " b"], + legal_next_tokens=["c"], + ), + "Token with Escape Characters": TestInputs( + grammar=r'?start: "a\n" ("\t")? "b"', + vocabulary=["a\nb", "a", "b", "\n", "\tb", "\t"], + generated=["a", "\n"], + legal_next_tokens=["b", "\t", "\tb"], + ), + "Complex Nesting": TestInputs( + grammar=r'?start: "a" ("b" ("c" "d"))', + vocabulary=["a", "b", "c", "d"], + generated=["a", "b", "c"], + legal_next_tokens=["d"], + ), + "Repeated Optional Patterns": TestInputs( + grammar=r'?start: ("a" ["b"])*', + vocabulary=["a", "b"], + generated=["a", "b", "a"], + legal_next_tokens=["a", "b", None], + ), + "Multiple Non-Terminal Symbols": TestInputs( + grammar=r""" + ?start: A B + A: "a" + B: "b" + """, + vocabulary=["a", "b"], + generated=["a"], + legal_next_tokens=["b"], + ), + "Recursive Definitions": TestInputs( + grammar=r""" + ?start: term_a + term_a: "a" term_a | "b" + """, + vocabulary=["a", "b"], + generated=["a", "a"], + legal_next_tokens=["a", "b"], + ), + "Ignored Patterns": TestInputs( + grammar=r""" + ?start: "a" "b" "c" + %ignore /\s+/ + """, + vocabulary=["a", "b", "c", " "], + generated=["a", " ", "b"], + legal_next_tokens=["c", " "], + ), + "Cross-References": TestInputs( + grammar=r""" + ?start: term_a + term_a : /a/ term_b + term_b : /b/ term_a | /c/ + """, + vocabulary=["a", "b", "c", "bac"], + generated=["a"], + legal_next_tokens=["b", "bac", "c"], + ), + "Multiple Complex Non-Terminal Rules": TestInputs( + grammar=r""" + ?start: S1 S2 S3 + S1: "a" | "b" + S2: "c" | "d" + S3: "e" "f" | "g" + """, + vocabulary=["a", "b", "c", "d", "e", "f", "g"], + generated=["a", "c"], + legal_next_tokens=["e", "g"], + ), + # + # + # TODO: fix + # parser incorrectly requires consumption of a or b in this case for unknown reasons + # "Nested Patterns with Repetition": TestInputs( + # grammar=r'?start: "a" ("b" | "c")* "d"', + # vocabulary=["a", "b", "bc", "bcc", "cbc", "bcbc", "cbccccd", "c", "d" "bcdcb"], + # generated=["a"], + # legal_next_tokens=["b", "bc", "bcc", "cbc", "bcbc", "cbccccd", "c", "d"], + # ), + # + # + # TODO: fix + # adjacent terminals with ambiguous starts and ends not handled properly + # ensure parser isn't greedy incorrectly + # "Ambiguous Overlapping Patterns": TestInputs( + # grammar=r"?start: /ab*/ /bc?/", + # vocabulary=["a", "ab", "abc", "b", "bc", "c"], + # generated=["a", "b", "b"], + # legal_next_tokens=["b", "c", "bc"], + # ), + # "Ambiguous Overlapping Patterns In Generation": TestInputs( + # grammar=r"?start: /ab*/ /bc?/", + # vocabulary=["a", "ab", "abc", "b", "bc", "c", "abbbc"], + # generated=["a", "b", "b"], + # legal_next_tokens=["b", "c", "bc"], + # ), + # + # + # SKIP: + # Awaiting negative lookarounds in interegular + # "Lookahead and Lookbehind with Nested Conditions": TestInputs( + # grammar=r'?start: /(?<=a)b(?=c)/ "d"', + # vocabulary=["a", "b", "c", "d"], + # generated=["a", "b", "c"], + # legal_next_tokens=["d"] + # ), + # "Lookbehind Patterns": TestInputs( + # grammar=r'?start: /(?<=a)b/ "c"', + # vocabulary=["a", "b", "c"], + # generated=["a", "b"], + # legal_next_tokens=["c"] + # ), +} + + +@pytest.mark.parametrize("name", cfg_test_inputs.keys()) +def test_cfg_next_token(name, cleanup_lark_import): + inputs = cfg_test_inputs[name] + + class MockTokenizer: + vocabulary = {token: i + 1 for i, token in enumerate(inputs.vocabulary)} + vocabulary[""] = 0 + reverse_vocab = {i: tok for tok, i in vocabulary.items()} + special_tokens = {""} + eos_token_id = 0 + + def convert_token_to_string(self, token): + return token + + def decode(self, token_ids): + if isinstance(token_ids[0], list): + return [ + "".join(map(self.reverse_vocab.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.reverse_vocab[token_id] for token_id in token_ids] + + # create a guide and the appropriate state advanced + # per the inputs generated tokens + tokenizer = MockTokenizer() + guide = CFGGuide(inputs.grammar, tokenizer) + state = guide.initial_state + for token in inputs.generated: + state = guide.get_next_state(state, tokenizer.vocabulary[token]) + instruction = guide.get_next_instruction(state) + + # normalize expectations and returned tokens for simple comparison + returned_next_tokens = sorted( + {tokenizer.reverse_vocab[int(t)] for t in instruction.tokens} + ) + expected_next_tokens = sorted( + { + t + if t is not None + else tokenizer.reverse_vocab[tokenizer.eos_token_id] # None -> "" + for t in inputs.legal_next_tokens + } + ) + + assert returned_next_tokens == expected_next_tokens + + +@pytest.fixture(scope="session") +def tokenizer_sentencepiece_gpt2(): + return models.TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2")) + + +@pytest.fixture(scope="session") +def tokenizer_sentencepiece_llama1(): + return models.TransformerTokenizer( + AutoTokenizer.from_pretrained( + "trl-internal-testing/tiny-random-LlamaForCausalLM" + ) + ) + + +@pytest.fixture(scope="session") +def tokenizer_tiktoken_llama3(): + return models.TransformerTokenizer( + AutoTokenizer.from_pretrained("yujiepan/llama-3-tiny-random") + ) + + +@pytest.fixture(scope="session") +def tokenizer_character_level_byt5(): + return models.TransformerTokenizer( + AutoTokenizer.from_pretrained("google/byt5-small") + ) + + +# Collects all samples within cfg_samples/ and makes adding +# a test case as easy as adding a valid sample to cfg_samples/ +all_samples = {} +examples_path = Path(__file__).parent.parent / "cfg_samples" +for sample_collection_path in examples_path.iterdir(): + grammar_name = sample_collection_path.name + grammar = getattr(grammars, grammar_name) + for sample_path in sample_collection_path.iterdir(): + test_name = f"{grammar_name}_{sample_path.name}" + with open(sample_path) as f: + all_samples[test_name] = (grammar_name, grammar, f.read().rstrip("\n")) + + +@pytest.mark.parametrize("sample_name", all_samples.keys()) +def test_cfg_test_sample_valid_with_lark(sample_name): + """assert the provided sample is valid (testing the test itself)""" + from lark import Lark, UnexpectedToken + + grammar_name, grammar_str, sample = all_samples[sample_name] + try: + parser = Lark(grammar_str, parser="lalr", import_paths=[grammars.GRAMMAR_PATH]) + parser = parser.parse_interactive(sample) + token = parser.exhaust_lexer()[-1] + parser.feed_eof(token) + except UnexpectedToken as e: + raise Exception( + f"Invalid test, sample '{sample_name}' isn't a legal generation of '{grammar_name}':\n{e}" + ) + + +@pytest.mark.parametrize("sample_name", all_samples.keys()) +@pytest.mark.parametrize( + "tokenizer_name", + [ + "tokenizer_sentencepiece_gpt2", + "tokenizer_sentencepiece_llama1", + "tokenizer_tiktoken_llama3", + "tokenizer_character_level_byt5", + ], +) +def test_cfg_grammar_sample(request, sample_name, tokenizer_name, cleanup_lark_import): + """Test whether CFG can generate the exact token sequence as tokenizer.encode(sample) produces""" + + # TODO: enable these tests once improvements are made + if ( + tokenizer_name != "tokenizer_character_level_byt5" + or sample_name == "json_outlines.generate.samplers.mypy.json.test" + ): + pytest.skip("CFG is too slow, skipping tests for this tokenizer") + elif sample_name == "arithmetic_lots_of_ops.arithmetic.test": + pytest.skip("CFG incorrectly handles this valid sample, skipping until bugfix") + + tokenizer = request.getfixturevalue(tokenizer_name) + + grammar_name, grammar_str, sample = all_samples[sample_name] + cfg_guide = CFGGuide(grammar_str, tokenizer) + + sample_token_ids = tokenizer.tokenizer.encode( + sample, add_special_tokens=False, return_tensors="pt" + )[0] + assert ( + len(sample_token_ids.shape) == 1 + ) # ensure we're encoding in the desired shape for this test + + state = cfg_guide.initial_state + for i, token_id in enumerate(sample_token_ids): + if tokenizer.decode([token_id])[0] == "": + continue + next_instruction = cfg_guide.get_next_instruction(state) + if token_id not in next_instruction.tokens: + processed_str = tokenizer.decode([sample_token_ids[:i]])[0] + remaining_str = tokenizer.decode([sample_token_ids[i:]])[0] + if next_instruction.tokens == [tokenizer.eos_token_id]: + error_label = "CFGGuide required EOS early" + else: + expected = tokenizer.decode(next_instruction.tokens) + error_label = ( + f"Mismatched expectations, Guide expected {sorted(expected)}" + ) + raise Exception( + f"{error_label}\n" + f"processed:\n```{processed_str}```\n" + f"remaining:\n```{remaining_str}```" + ) + next_instruction.tokens + state = cfg_guide.get_next_state(state, token_id) + final_instruction = cfg_guide.get_next_instruction(state) + assert tokenizer.eos_token_id in final_instruction.tokens diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py index 21163b70d..94166fd95 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_fsm.py @@ -1,6 +1,6 @@ import pytest -from outlines.fsm.fsm import CFGFSM, RegexFSM, StopAtEosFSM +from outlines.fsm.fsm import RegexFSM, StopAtEosFSM def assert_expected_tensor_ids(tensor, ids): @@ -90,263 +90,3 @@ def convert_token_to_string(self, token): state = fsm.next_state(state=5, token_id=103) assert fsm.is_final_state(state) - - -def test_cfg(): - class MockTokenizer: - vocabulary = {"{": 1, "}": 2, "[": 3, "]": 4, "eos": 5} - special_tokens = {"eos"} - eos_token = "eos" - eos_token_id = 5 - - def convert_token_to_string(self, token): - return token - - @property - def inverse_vocabulary(self): - return {v: k for k, v in self.vocabulary.items()} - - def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] - - cfg_str = """ - start: expr - expr: "{" expr "}" | "[" expr "]" | - """ - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = CFGFSM(cfg_str, tokenizer) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 3, 5]) - state = fsm.next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "{" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3]) - state = fsm.next_state(state=state, token_id=3) - assert fsm.generation == "{[" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3, 4]) - state = fsm.next_state(state=state, token_id=4) - assert fsm.generation == "{[]" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2]) - state = fsm.next_state(state=state, token_id=2) - assert fsm.generation == "{[]}" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [5]) - state = fsm.next_state(state=state, token_id=5) - assert fsm.generation == "{[]}" - assert fsm.is_final_state(state) - - -def test_cfg_early_termination(): - class MockTokenizer: - vocabulary = {"(": 1, ")": 2, "eos": 3} - special_tokens = {"eos"} - eos_token = "eos" - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - @property - def inverse_vocabulary(self): - return {v: k for k, v in self.vocabulary.items()} - - def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] - - cfg_str = """ - start: expr+ - expr: "(" subexpr ")" - subexpr: expr | - """ - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = CFGFSM(cfg_str, tokenizer) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1]) - state = fsm.next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "(" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2]) - state = fsm.next_state(state=state, token_id=2) - assert fsm.generation == "()" - assert not fsm.is_final_state(state) - - # possible to continue or terminate - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3]) - state = fsm.next_state(state=state, token_id=3) # feed eos - assert fsm.generation == "()" - assert fsm.is_final_state(state) - - # once eos generated, can only terminate - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3]) - - -def test_cfg_ignore_directive(): - class MockTokenizer: - vocabulary = {"a": 1, " ": 2, "eos": 3} - special_tokens = {"eos"} - eos_token = "eos" - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - @property - def inverse_vocabulary(self): - return {v: k for k, v in self.vocabulary.items()} - - def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] - - cfg_str = """ - start: LETTER+ - LETTER: "a" - WS: " " - %ignore WS - """ - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = CFGFSM(cfg_str, tokenizer) - - state = 0 - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1, 2]) - state = fsm.next_state(state=0, token_id=2) - assert fsm.generation == " " - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1, 2]) - state = fsm.next_state(state=0, token_id=1) - assert fsm.generation == " a" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3]) - state = fsm.next_state(state=state, token_id=2) - assert fsm.generation == " a " - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3]) - state = fsm.next_state(state=state, token_id=2) - assert fsm.generation == " a " - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3]) - state = fsm.next_state(state=state, token_id=1) - assert fsm.generation == " a a" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3]) - state = fsm.next_state(state=state, token_id=3) - assert fsm.generation == " a a" - assert fsm.is_final_state(state) - - # once eos generated, can only terminate - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3]) - - -def test_cfg_multitoken_terminal(): - class MockTokenizer: - vocabulary = {"a": 1, "b": 2, "eos": 3} - special_tokens = {"eos"} - eos_token = "eos" - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - @property - def inverse_vocabulary(self): - return {v: k for k, v in self.vocabulary.items()} - - def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] - - cfg_str = """ - start: S - S: "aa" | "bb" - """ - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = CFGFSM(cfg_str, tokenizer) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 2]) - assert fsm.reset_state # starting new regex - state = fsm.next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "a" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1]) - assert not fsm.reset_state # continuing current regex - state = fsm.next_state(state=state, token_id=1) - assert fsm.generation == "aa" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3]) - assert not fsm.reset_state # completing current regex - state = fsm.next_state(state=state, token_id=3) - assert fsm.generation == "aa" - assert fsm.is_final_state(state) - - -def test_cfg_allow_both_extend_and_shift_terminal(): - class MockTokenizer: - vocabulary = {"(": 1, ")": 2, "a": 3, "eos": 4} - special_tokens = {"eos"} - eos_token = "eos" - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - @property - def inverse_vocabulary(self): - return {v: k for k, v in self.vocabulary.items()} - - def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] - - cfg_str = """ - start: s - s: "(" s ")" | /a+/ - """ - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = CFGFSM(cfg_str, tokenizer) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 3]) - state = fsm.next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "(" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3]) - state = fsm.next_state(state=state, token_id=3) - assert fsm.generation == "(a" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2, 3]) - state = fsm.next_state(state=state, token_id=3) - assert fsm.generation == "(aa" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2, 3]) - state = fsm.next_state(state=state, token_id=2) - assert fsm.generation == "(aa)" - assert not fsm.is_final_state(state) - - assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [4]) - state = fsm.next_state(state=state, token_id=4) - assert fsm.generation == "(aa)" - assert fsm.is_final_state(state) diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 20ba75893..67b4e0dd8 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -205,49 +205,47 @@ def inverse_vocabulary(self): return {v: k for k, v in self.vocabulary.items()} def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] + if isinstance(token_ids[0], list): + return [ + "".join(map(self.inverse_vocabulary.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.inverse_vocabulary[token_id] for token_id in token_ids] cfg_str = """ start: expr expr: "{" expr "}" | "[" expr "]" | """ tokenizer = MockTokenizer() - fsm = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 3, 5]) - state = fsm.get_next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "{" - assert not fsm.is_final_state(state) + guide = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2, 3]) - state = fsm.get_next_state(state=state, token_id=3) - assert fsm.generation == "{[" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids( + guide.get_next_instruction(guide.initial_state).tokens, [1, 3, 5] + ) + state = guide.get_next_state(guide.initial_state, token_id=1) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 3, 4]) - state = fsm.get_next_state(state=state, token_id=4) - assert fsm.generation == "{[]" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2, 3]) + state = guide.get_next_state(state, token_id=3) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [2]) - state = fsm.get_next_state(state=state, token_id=2) - assert fsm.generation == "{[]}" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 3, 4]) + state = guide.get_next_state(state, token_id=4) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [5]) - state = fsm.get_next_state(state=state, token_id=5) - assert fsm.generation == "{[]}" - assert fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [2]) + state = guide.get_next_state(state, token_id=2) + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) + + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [5]) + state = guide.get_next_state(state, token_id=5) + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) def test_cfg_early_termination(): @@ -265,7 +263,12 @@ def inverse_vocabulary(self): return {v: k for k, v in self.vocabulary.items()} def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] + if isinstance(token_ids[0], list): + return [ + "".join(map(self.inverse_vocabulary.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.inverse_vocabulary[token_id] for token_id in token_ids] cfg_str = """ start: expr+ @@ -273,34 +276,29 @@ def decode(self, token_ids): subexpr: expr | """ tokenizer = MockTokenizer() - fsm = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - state = fsm.get_next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "(" - assert not fsm.is_final_state(state) + guide = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2]) - state = fsm.get_next_state(state=state, token_id=2) - assert fsm.generation == "()" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids( + guide.get_next_instruction(guide.initial_state).tokens, [1] + ) + state = guide.get_next_state(guide.initial_state, token_id=1) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) + + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2]) + state = guide.get_next_state(state, token_id=2) + assert not guide.must_terminate_state(state) + assert guide.can_terminate_state(state) # possible to continue or terminate - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 3]) - state = fsm.get_next_state(state=state, token_id=3) # feed eos - assert fsm.generation == "()" - assert fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 3]) + state = guide.get_next_state(state, token_id=3) # feed eos + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) # once eos generated, can only terminate - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [3]) def test_cfg_ignore_directive(): @@ -318,7 +316,12 @@ def inverse_vocabulary(self): return {v: k for k, v in self.vocabulary.items()} def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] + if isinstance(token_ids[0], list): + return [ + "".join(map(self.inverse_vocabulary.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.inverse_vocabulary[token_id] for token_id in token_ids] cfg_str = """ start: LETTER+ @@ -327,56 +330,43 @@ def decode(self, token_ids): %ignore WS """ tokenizer = MockTokenizer() - fsm = CFGGuide(cfg_str, tokenizer) - state = 0 + guide = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2]) - state = fsm.get_next_state(state=0, token_id=2) - assert fsm.generation == " " - assert not fsm.is_final_state(state) + state = guide.initial_state - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2]) - state = fsm.get_next_state(state=0, token_id=1) - assert fsm.generation == " a" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2]) + state = guide.get_next_state(state, token_id=2) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2, 3]) - state = fsm.get_next_state(state=state, token_id=2) - assert fsm.generation == " a " - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2]) + state = guide.get_next_state(state, token_id=1) + assert not guide.must_terminate_state(state) + assert guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2, 3]) - state = fsm.get_next_state(state=state, token_id=2) - assert fsm.generation == " a " - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2, 3]) + state = guide.get_next_state(state, token_id=2) + assert not guide.must_terminate_state(state) + assert guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2, 3]) - state = fsm.get_next_state(state=state, token_id=1) - assert fsm.generation == " a a" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2, 3]) + state = guide.get_next_state(state, token_id=2) + assert not guide.must_terminate_state(state) + assert guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2, 3]) - state = fsm.get_next_state(state=state, token_id=3) - assert fsm.generation == " a a" - assert fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2, 3]) + state = guide.get_next_state(state, token_id=1) + assert not guide.must_terminate_state(state) + assert guide.can_terminate_state(state) + + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 2, 3]) + state = guide.get_next_state(state, token_id=3) + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) # once eos generated, can only terminate - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [3]) def test_cfg_multitoken_terminal(): @@ -394,38 +384,37 @@ def inverse_vocabulary(self): return {v: k for k, v in self.vocabulary.items()} def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] + if isinstance(token_ids[0], list): + return [ + "".join(map(self.inverse_vocabulary.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.inverse_vocabulary[token_id] for token_id in token_ids] cfg_str = """ start: S S: "aa" | "bb" """ tokenizer = MockTokenizer() - fsm = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 2]) - assert fsm.reset_state # starting new regex - state = fsm.get_next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "a" - assert not fsm.is_final_state(state) + guide = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - assert not fsm.reset_state # continuing current regex - state = fsm.get_next_state(state=state, token_id=1) - assert fsm.generation == "aa" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids( + guide.get_next_instruction(guide.initial_state).tokens, [1, 2] + ) + state = guide.get_next_state(guide.initial_state, token_id=1) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) - assert not fsm.reset_state # completing current regex - state = fsm.get_next_state(state=state, token_id=3) - assert fsm.generation == "aa" - assert fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1]) + state = guide.get_next_state(state, token_id=1) + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) + + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [3]) + state = guide.get_next_state(state, token_id=3) + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) def test_cfg_allow_both_extend_and_shift_terminal(): @@ -443,46 +432,45 @@ def inverse_vocabulary(self): return {v: k for k, v in self.vocabulary.items()} def decode(self, token_ids): - return [self.inverse_vocabulary[t] for t in token_ids] + if isinstance(token_ids[0], list): + return [ + "".join(map(self.inverse_vocabulary.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.inverse_vocabulary[token_id] for token_id in token_ids] cfg_str = """ start: s s: "(" s ")" | /a+/ """ tokenizer = MockTokenizer() - fsm = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 3]) - state = fsm.get_next_state(state=fsm.start_state, token_id=1) - assert fsm.generation == "(" - assert not fsm.is_final_state(state) + guide = CFGGuide(cfg_str, tokenizer) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1, 3]) - state = fsm.get_next_state(state=state, token_id=3) - assert fsm.generation == "(a" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids( + guide.get_next_instruction(guide.initial_state).tokens, [1, 3] + ) + state = guide.get_next_state(guide.initial_state, token_id=1) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [2, 3]) - state = fsm.get_next_state(state=state, token_id=3) - assert fsm.generation == "(aa" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [1, 3]) + state = guide.get_next_state(state, token_id=3) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [2, 3]) - state = fsm.get_next_state(state=state, token_id=2) - assert fsm.generation == "(aa)" - assert not fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [2, 3]) + state = guide.get_next_state(state, token_id=3) + assert not guide.must_terminate_state(state) + assert not guide.can_terminate_state(state) - instruction = fsm.get_next_instruction(state) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [4]) - state = fsm.get_next_state(state=state, token_id=4) - assert fsm.generation == "(aa)" - assert fsm.is_final_state(state) + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [2, 3]) + state = guide.get_next_state(state, token_id=2) + + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) + + assert_expected_tensor_ids(guide.get_next_instruction(state).tokens, [4]) + state = guide.get_next_state(state, token_id=4) + assert guide.must_terminate_state(state) + assert guide.can_terminate_state(state) diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index a86d3c253..a7e9a696d 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -129,6 +129,17 @@ def sample_choices(): return ["foo", "bar", "baz"] +@pytest.fixture() +def sample_lark_grammar(): + # from https://github.com/lark-parser/lark/blob/master/docs/grammar.md + return """ + ?start: hello_world "!" number + hello_world: ("hello" | "world") ~ 3 + number: ("0".."9") ~ 5 + thanks: "Thank"i " for testing!" + """ + + REGEX_PATTERNS = [ "a b c d e", # ensure proper tokenizer whitespace prefix handling "(123456789)|(abcdefghijklmnop)", # ensure consistent correct sequence handling during batch @@ -151,6 +162,7 @@ def enforce_not_implemented(model_fixture, *task_names): "batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"], "beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"], "multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"], + "cfg": ["model_llamacpp"], # TODO: fix llama_cpp tokenizer } for task_name in task_names: if model_fixture in NOT_IMPLEMENTED.get(task_name, []): @@ -247,6 +259,28 @@ def test_generate_format_bool(request, model_fixture): assert isinstance(res, bool) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_cfg(request, model_fixture, sample_lark_grammar): + from lark import Lark + + from outlines import grammars + + model = request.getfixturevalue(model_fixture) + with enforce_not_implemented(model_fixture, "cfg"): + generator = generate.cfg(model, sample_lark_grammar) + res = generator(**get_inputs(model_fixture)) + # validate legal with the grammar via lark + # TODO: cleanup PartialLark so doesn't modify Lark globally + import importlib + + import lark.lark + + importlib.reload(lark.lark) + Lark( + sample_lark_grammar, parser="lalr", import_paths=[grammars.GRAMMAR_PATH] + ).parse(res) + + @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) def test_generate_text_stream(request, model_fixture): model = request.getfixturevalue(model_fixture) diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 0d1908ebd..3469dcbc0 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, constr import outlines.generate as generate -import outlines.grammars as grammars import outlines.models as models import outlines.samplers as samplers @@ -243,15 +242,6 @@ def test_llamacpp_json_schema(model): assert isinstance(result["bar"], str) -def test_llamacpp_cfg(model): - prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n" - - # remove this statement once cfg is implemented - with pytest.raises(NotImplementedError): - result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) - assert isinstance(result, str) - - @pytest.mark.parametrize( "repo,model_path,hf_tokenizer_uri", [ @@ -319,15 +309,15 @@ def test_RegexGuide_caching(model, temp_cache_dir): # These two different models and tokenizers should not have the same state # mapping results assert ( - generator.logits_processor.fsm.states_to_token_maps - != generator_2.logits_processor.fsm.states_to_token_maps + generator.logits_processor.guide.states_to_token_maps + != generator_2.logits_processor.guide.states_to_token_maps ) generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy()) assert cache.stats() == (1, 2) assert ( - generator_2.logits_processor.fsm.states_to_token_maps - == generator_3.logits_processor.fsm.states_to_token_maps + generator_2.logits_processor.guide.states_to_token_maps + == generator_3.logits_processor.guide.states_to_token_maps ) # Just for fun... diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index f3fb9682e..cdd57e0c6 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -524,15 +524,15 @@ def test_RegexGuide_caching(temp_cache_dir): # These two different models and tokenizers should not have the same state # mapping results assert ( - generator.logits_processor.fsm.states_to_token_maps - != generator_2.logits_processor.fsm.states_to_token_maps + generator.logits_processor.guide.states_to_token_maps + != generator_2.logits_processor.guide.states_to_token_maps ) generator_3 = generate.regex(model_2, regex, sampler=greedy()) assert cache.stats() == (1, 2) assert ( - generator_2.logits_processor.fsm.states_to_token_maps - == generator_3.logits_processor.fsm.states_to_token_maps + generator_2.logits_processor.guide.states_to_token_maps + == generator_3.logits_processor.guide.states_to_token_maps ) # Just for fun...