diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index 8b96c9dca..c076a2e4e 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -1,8 +1,10 @@ import itertools import json +import re from typing import Dict -STRING = r'".*"' +STRING_INNER = r'(?:[^"\\]|\\.)' +STRING = f'"{STRING_INNER}*"' INTEGER = r"(0|[1-9][0-9]*)" NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" @@ -192,7 +194,7 @@ def match_step_to_regex(step): Parameters ---------- step: - A string that represents the schema's structure, or a dictionnary + A string that represents the schema's structure, or a dictionary that represents a field in the schema. Returns @@ -203,13 +205,13 @@ def match_step_to_regex(step): """ match step: case str() as step: - return step + return re.escape(step) case {"enum": choices, "type": "string"}: - choices = [f'"{choice}"' for choice in choices] + choices = [f'"{re.escape(choice)}"' for choice in choices] return f"({'|'.join(choices)})" case {"enum": choices}: - choices = [str(choice) for choice in choices] + choices = [re.escape(str(choice)) for choice in choices] return f"({'|'.join(choices)})" case {"type": "array", "items": items}: @@ -224,9 +226,9 @@ def match_step_to_regex(step): return regex_str case {"type": "string", "maxLength": max_length}: - return f'".{{,{max_length}}}"' + return f'"{STRING_INNER}{{,{max_length}}}"' case {"type": "string", "minLength": min_length}: - return f'".{{{min_length},}}"' + return f'"{STRING_INNER}{{{min_length},}}"' case {"type": field_type}: return type_to_regex[field_type] diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index 4e3018e0d..f8814aeba 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -12,6 +12,7 @@ NULL, NUMBER, STRING, + STRING_INNER, build_schedule_from_schema, match_step_to_regex, ) @@ -258,13 +259,13 @@ def test_match_number(pattern, does_match): ), ( {"title": "Foo", "type": "string", "maxLength": 3}, - '".{,3}"', - [('"ab"', True), ('"abcd"', False)], + f'"{STRING_INNER}{{,3}}"', + [('"ab"', True), ('"a""', False), ('"abcd"', False)], ), ( {"title": "Foo", "type": "string", "minLength": 3}, - '".{3,}"', - [('"ab"', False), ('"abcd"', True)], + f'"{STRING_INNER}{{3,}}"', + [('"ab"', False), ('"abcd"', True), ('"abc""', False)], ), ( {"title": "Foo", "type": "boolean"}, @@ -290,6 +291,7 @@ def test_match_number(pattern, does_match): f"({STRING}|{NUMBER})", [ ('"string"', True), + ('"st"ring"', False), ("1000", True), ("true", False), ], @@ -299,6 +301,11 @@ def test_match_number(pattern, does_match): '("Marc"|"Jean")', [('"Marc"', True), ('"Jean"', True), ('"John"', False)], ), + ( + {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"}, + r'("\.\*"|"\\s\*")', + [('".*"', True), (r'"\s*"', True), (r'"\.\*"', False)], + ), ( {"title": "Foo", "enum": [0, 1], "type": "integer"}, "(0|1)", @@ -310,7 +317,7 @@ def test_match_number(pattern, does_match): "type": "object", "properties": {"count": {"title": "Count", "type": "integer"}}, }, - '{\n "count": ' + INTEGER + "\n}", + '\\{\\\n\\ \\ "count":\\ ' + INTEGER + "\\\n\\}", [('{\n "count": 100\n}', True)], ), ( @@ -339,16 +346,20 @@ def test_match_number(pattern, does_match): } }, }, - '{\n "fuzz": {\n "spam": ' + INTEGER + "\n }\n}", + '\\{\\\n\\ \\ "fuzz":\\ \\{\\\n\\ \\ \\ \\ "spam":\\ ' + + INTEGER + + "\\\n\\ \\ \\}\\\n\\}", [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], ), ], ) def test_match(step, regex, examples): - assert match_step_to_regex(step) == regex + test_regex = match_step_to_regex(step) + + assert test_regex == regex for string, does_match in examples: - match = re.fullmatch(regex, string) + match = re.fullmatch(test_regex, string) if does_match: assert match[0] == string assert match.span() == (0, len(string))