Skip to content

Commit

Permalink
Escape special characters in JSON structure
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 10, 2023
1 parent ce0fad4 commit 23299e0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
9 changes: 5 additions & 4 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import json
import re
from typing import Dict

STRING = r'".*"'
Expand Down Expand Up @@ -192,7 +193,7 @@ def match_step_to_regex(step):
Parameters
----------
step:
A string that represents the schema's structure, or a dictionnary
A string that represents the schema's structure, or a dictionary
that represents a field in the schema.
Returns
Expand All @@ -203,13 +204,13 @@ def match_step_to_regex(step):
"""
match step:
case str() as step:
return step
return re.escape(step)

case {"enum": choices, "type": "string"}:
choices = [f'"{choice}"' for choice in choices]
choices = [f'"{re.escape(choice)}"' for choice in choices]
return f"({'|'.join(choices)})"
case {"enum": choices}:
choices = [str(choice) for choice in choices]
choices = [re.escape(str(choice)) for choice in choices]
return f"({'|'.join(choices)})"

case {"type": "array", "items": items}:
Expand Down
17 changes: 13 additions & 4 deletions tests/text/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ def test_match_number(pattern, does_match):
'("Marc"|"Jean")',
[('"Marc"', True), ('"Jean"', True), ('"John"', False)],
),
(
{"title": "Foo", "enum": [".*", r"\s*"], "type": "string"},
r'("\.\*"|"\\s\*")',
[('".*"', True), (r'"\s*"', True), (r'"\.\*"', False)],
),
(
{"title": "Foo", "enum": [0, 1], "type": "integer"},
"(0|1)",
Expand All @@ -310,7 +315,7 @@ def test_match_number(pattern, does_match):
"type": "object",
"properties": {"count": {"title": "Count", "type": "integer"}},
},
'{\n "count": ' + INTEGER + "\n}",
'\\{\\\n\\ \\ "count":\\ ' + INTEGER + "\\\n\\}",
[('{\n "count": 100\n}', True)],
),
(
Expand Down Expand Up @@ -339,16 +344,20 @@ def test_match_number(pattern, does_match):
}
},
},
'{\n "fuzz": {\n "spam": ' + INTEGER + "\n }\n}",
'\\{\\\n\\ \\ "fuzz":\\ \\{\\\n\\ \\ \\ \\ "spam":\\ '
+ INTEGER
+ "\\\n\\ \\ \\}\\\n\\}",
[('{\n "fuzz": {\n "spam": 100\n }\n}', True)],
),
],
)
def test_match(step, regex, examples):
assert match_step_to_regex(step) == regex
test_regex = match_step_to_regex(step)

assert test_regex == regex

for string, does_match in examples:
match = re.fullmatch(regex, string)
match = re.fullmatch(test_regex, string)
if does_match:
assert match[0] == string
assert match.span() == (0, len(string))
Expand Down

0 comments on commit 23299e0

Please sign in to comment.