From f1e925f0b359237914ad5e48de69fc6b55426dd8 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 16 Sep 2023 01:05:09 -0500 Subject: [PATCH] Fix whitespace and control character handling in JSON guidance --- outlines/text/json_schema.py | 25 ++++---- .../generate/test_integration_transfomers.py | 2 +- tests/text/test_json_schema.py | 62 +++++++++---------- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index c076a2e4e..7fffa0fa5 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -3,7 +3,7 @@ import re from typing import Dict -STRING_INNER = r'(?:[^"\\]|\\.)' +STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)' STRING = f'"{STRING_INNER}*"' INTEGER = r"(0|[1-9][0-9]*)" NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" @@ -142,7 +142,7 @@ def expand_json_schema(raw_schema: Dict, definitions: Dict): return raw_schema -def build_schedule_from_instance(instance: Dict, indent: int = 0): +def build_schedule_from_instance(instance: Dict): """Build a generation schedule from a instance. This recursively follows the references to other instances. @@ -163,27 +163,26 @@ def build_schedule_from_instance(instance: Dict, indent: int = 0): """ schedule = [] if "properties" in instance: - schedule.append("{\n") - schedule += build_schedule_from_instance(instance["properties"], indent + 2) - if indent > 0: - schedule.append(" " * indent) - schedule.append("}") + schedule.append(r"\{") + schedule += build_schedule_from_instance(instance["properties"]) + schedule.append(r"\}") else: for i, (name, annotation) in enumerate(instance.items()): - schedule.append(" " * indent) - schedule.append(f'"{name}": ') + whitespace = r"[\n ]*" + schedule.append(f'{whitespace}"{name}"{whitespace}:{whitespace}') + if "anyOf" in annotation: schedule.append(annotation) elif annotation["type"] == "object": - schedule += build_schedule_from_instance(annotation, indent) + schedule += build_schedule_from_instance(annotation) else: schedule.append(annotation) # We cannot add commas after the last key-value pair in JSON if i == len(instance) - 1: - schedule.append("\n") + schedule.append(whitespace) else: - schedule.append(",\n") + schedule.append(f"{whitespace},") return schedule @@ -205,7 +204,7 @@ def match_step_to_regex(step): """ match step: case str() as step: - return re.escape(step) + return step case {"enum": choices, "type": "string"}: choices = [f'"{re.escape(choice)}"' for choice in choices] diff --git a/tests/text/generate/test_integration_transfomers.py b/tests/text/generate/test_integration_transfomers.py index 6d73e3a8a..912e47758 100644 --- a/tests/text/generate/test_integration_transfomers.py +++ b/tests/text/generate/test_integration_transfomers.py @@ -136,7 +136,7 @@ class Spam(BaseModel): sequence = generate.json(model, Spam, max_tokens=1000)(prompt, rng=rng) parsed = json.loads(sequence) assert isinstance(parsed["foo"], int) - assert isinstance(parsed["bar"], float) + assert isinstance(parsed["bar"], int) assert isinstance(parsed["spam"], str) assert isinstance(parsed["fuzz"], bool) assert len(parsed["spam"]) == 10 diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index f8814aeba..68737788c 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -30,19 +30,19 @@ class User(BaseModel): schema = json.dumps(User.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "user_id": ', + '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', {"title": "User Id", "type": "integer"}, - ',\n "name": ', + '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', {"title": "Name", "type": "string"}, - ',\n "maxlength_name": ', + '[\\n ]*,[\\n ]*"maxlength_name"[\\n ]*:[\\n ]*', {"title": "Maxlength Name", "type": "string", "maxLength": 10}, - ',\n "minlength_name": ', + '[\\n ]*,[\\n ]*"minlength_name"[\\n ]*:[\\n ]*', {"title": "Minlength Name", "type": "string", "minLength": 10}, - ',\n "value": ', + '[\\n ]*,[\\n ]*"value"[\\n ]*:[\\n ]*', {"title": "Value", "type": "number"}, - ',\n "is_true": ', + '[\\n ]*,[\\n ]*"is_true"[\\n ]*:[\\n ]*', {"title": "Is True", "type": "boolean"}, - "\n}", + "[\\n ]*\\}", ] @@ -53,9 +53,9 @@ class Foo(BaseModel): schema = json.dumps(Foo.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "bar": ', + '\\{[\\n ]*"bar"[\\n ]*:[\\n ]*', {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Bar"}, - "\n}", + "[\\n ]*\\}", ] @@ -67,11 +67,11 @@ class User(BaseModel): schema = json.dumps(User.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "user_id": ', + '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', {"title": "User Id", "type": "integer"}, - ',\n "value": ', + '[\\n ]*,[\\n ]*"value"[\\n ]*:[\\n ]*', {"title": "Value", "type": "array", "items": {"type": "number"}}, - "\n}", + "[\\n ]*\\}", ] @@ -88,15 +88,15 @@ class User(BaseModel): schema = json.dumps(User.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "user_id": ', + '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', {"title": "User Id", "type": "integer"}, - ',\n "name": ', + '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', { "title": "Name", "enum": ["John", "Marc", "Michel"], "type": "string", }, - "\n}", + "[\\n ]*\\}", ] @@ -122,15 +122,15 @@ class Spam(BaseModel): schema = json.dumps(Spam.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "foo": {\n "count": ', + '\\{[\\n ]*"foo"[\\n ]*:[\\n ]*\\{[\\n ]*"count"[\\n ]*:[\\n ]*', {"title": "Count", "type": "integer"}, - ',\n "size": {\n "buzz": ', + '[\\n ]*,[\\n ]*"size"[\\n ]*:[\\n ]*\\{[\\n ]*"buzz"[\\n ]*:[\\n ]*', {"title": "Buzz", "type": "string"}, - '\n }\n },\n "bars": {\n "apple": ', + '[\\n ]*\\}[\\n ]*\\}[\\n ]*,[\\n ]*"bars"[\\n ]*:[\\n ]*\\{[\\n ]*"apple"[\\n ]*:[\\n ]*', {"title": "Apple", "type": "string"}, - ',\n "banana": ', + '[\\n ]*,[\\n ]*"banana"[\\n ]*:[\\n ]*', {"title": "Banana", "type": "string"}, - "\n }\n}", + "[\\n ]*\\}[\\n ]*\\}", ] @@ -145,7 +145,7 @@ class Spam(BaseModel): schema = json.dumps(Spam.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "foo": ', + '\\{[\\n ]*"foo"[\\n ]*:[\\n ]*', { "items": { "title": "Foo", @@ -155,7 +155,7 @@ class Spam(BaseModel): "title": "Foo", "type": "array", }, - "\n}", + "[\\n ]*\\}", ] @@ -169,11 +169,11 @@ class Spam(BaseModel): schema = json.dumps(Spam.model_json_schema()) schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "foo": ', + '\\{[\\n ]*"foo"[\\n ]*:[\\n ]*', {"title": "Foo", "type": "integer"}, - ',\n "bar": ', + '[\\n ]*,[\\n ]*"bar"[\\n ]*:[\\n ]*', {"title": "Bar", "anyOf": [{"type": "number"}, {"type": "string"}]}, - "\n}", + "[\\n ]*\\}", ] @@ -181,11 +181,11 @@ def test_json_schema(): schema = '{"title": "User", "type": "object", "properties": {"user_id": {"title": "User Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}}, "required": ["user_id", "name"]}' schedule = build_schedule_from_schema(schema) assert schedule == [ - '{\n "user_id": ', + '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', {"title": "User Id", "type": "integer"}, - ',\n "name": ', + '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', {"title": "Name", "type": "string"}, - "\n}", + "[\\n ]*\\}", ] @@ -317,7 +317,7 @@ def test_match_number(pattern, does_match): "type": "object", "properties": {"count": {"title": "Count", "type": "integer"}}, }, - '\\{\\\n\\ \\ "count":\\ ' + INTEGER + "\\\n\\}", + '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*\\}', [('{\n "count": 100\n}', True)], ), ( @@ -346,9 +346,7 @@ def test_match_number(pattern, does_match): } }, }, - '\\{\\\n\\ \\ "fuzz":\\ \\{\\\n\\ \\ \\ \\ "spam":\\ ' - + INTEGER - + "\\\n\\ \\ \\}\\\n\\}", + f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], ), ],