Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve 833 #29

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion outlines/fsm/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from outlines.fsm.regex import (
fsm_union,
get_sub_fsms_from_seq,
get_token_transition_keys,
make_deterministic_fsm,
walk_fsm,
)
Expand Down Expand Up @@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None)

text_part = text[start_pos:]

text_transitions = get_token_transition_keys(
self.fsm.fsm_info.alphabet_symbol_mapping,
self.fsm.fsm_info.alphabet_anything_value,
text_part,
)

state_seq = walk_fsm(
self.fsm,
text_part,
text_transitions,
start_state,
full_match=self.match_whole,
)
Expand Down
125 changes: 89 additions & 36 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,11 @@ def fsm_info(self):
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("int64, int64"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U2, int64"),
)
alphabet_symbol_mapping_items = [
(k, v)
for k, v in self.alphabet._symbol_mapping.items()
if k != anything_else
]
nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
Expand All @@ -110,7 +107,7 @@ def fsm_info(self):

nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
nb_unicode_type = numba.types.unicode_type


@numba.njit(cache=True)
Expand All @@ -136,7 +133,7 @@ def create_fsm_info(

# use 2-char strings so that we can represent incomplete utf-8 sequences
# as 2-hex-digit pairs
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

Expand Down Expand Up @@ -199,7 +196,7 @@ def transition_trie_setdefault(


def byte_symbol(byte: int) -> str:
return f"{byte:02X}" if byte >= 0x80 else chr(byte)
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
Expand Down Expand Up @@ -415,21 +412,19 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: Sequence[str],
token_transition_keys: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_transition_keys):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand All @@ -453,7 +448,7 @@ def _walk_fsm(

def walk_fsm(
fsm: BetterFSM,
input_string: Sequence[str],
token_transition_keys: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
Expand All @@ -463,13 +458,11 @@ def walk_fsm(
accepted_states: List[int] = []
last_final_idx: int = 0

alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_transition_keys):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand Down Expand Up @@ -655,24 +648,25 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
vocabulary_transition_keys: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary:
for (token, token_ids), token_transition_keys in zip(
vocabulary, vocabulary_transition_keys
):
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
token_transition_keys,
start_state,
False,
)

if state_seq is not None and len(state_seq) < len(token):
if state_seq is not None and len(state_seq) < len(token_transition_keys):
continue

for token_id in token_ids:
Expand All @@ -681,9 +675,62 @@ def state_scan_tokens(
return res


@numba.njit(cache=True, nogil=True)
def get_token_transition_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
token_str: str,
) -> Sequence[int]:
"""
Get the sequence of transition keys for an individual string
with respect to an FSMs alphabet symbol mapping

This requires parsing the null-byte prefix rules of a byte-fsm:
- If two characters are prefixed by \x00, they are the grouped as a hex-byte
- Otherwise they are a standalone utf-8 character
"""
token_transition_keys = []
i = 0
while i < len(token_str):
if token_str[i] == "\x00" and i != len(token_str) - 1:
symbol = token_str[i : i + 3]
i += 3
else:
symbol = token_str[i]
i += 1

token_transition_keys.append(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64)
for j in range(len(token_transition_keys)):
token_transition_keys_array[j] = token_transition_keys[j]
return token_transition_keys_array


@numba.njit(cache=True, nogil=True)
def get_vocabulary_transition_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
) -> List[Sequence[int]]:
"""
Calculate the sequence transition keys for each token str within a vocabulary
"""
vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
token_transition_keys = get_token_transition_keys(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
vocab_transition_keys.append(token_transition_keys)

return vocab_transition_keys


def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

Expand All @@ -699,6 +746,12 @@ def create_fsm_index_end_to_end(
desc="Compiling FSM index for all state transitions",
)

vocabulary_transition_keys = get_vocabulary_transition_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

Expand All @@ -709,6 +762,7 @@ def create_fsm_index_end_to_end(
fsm_info.initial,
fsm_info.finals,
vocabulary,
vocabulary_transition_keys,
start_state,
)

Expand Down Expand Up @@ -771,7 +825,7 @@ def gpt2_unicode_to_bytes():
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
Expand Down Expand Up @@ -804,7 +858,7 @@ def reduced_vocabulary(
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = tuple(byte_symbol(b) for b in token_bytes)
token_str = "".join(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
else:
Expand All @@ -813,15 +867,14 @@ def reduced_vocabulary(
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unichar_2_type[:],
nb_unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
for token_str, token_ids in vocabulary.items():
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))
vocabulary_nb.append((token_str, token_ids_np))

return vocabulary_nb, empty_token_ids

Expand Down
17 changes: 9 additions & 8 deletions tests/fsm/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from outlines.fsm.parsing import PartialLark, PartialPythonIndenter


def test_partial_parsing():
@pytest.fixture
def cleanup_lark_import():
yield
# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_partial_parsing(cleanup_lark_import):
lp = PartialLark.open_from_package(
"tests",
"partial_python.lark",
Expand Down Expand Up @@ -136,11 +143,8 @@ def test_partial_parsing():
assert len(parser_state.state_stack) == 4
assert parser_state.value_stack[-1].type == "LPAR"

# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_sequential_parse_example():
def test_sequential_parse_example(cleanup_lark_import):
input_tokens = [
"x ",
"= ",
Expand Down Expand Up @@ -200,6 +204,3 @@ def test_sequential_parse_example():

if i + 1 == len(input_tokens):
assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])

# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)
Loading
Loading