diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9039b605..de10b27cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,4 +30,4 @@ repos: - id: mypy args: [--allow-redefinition] exclude: ^examples/ - additional_dependencies: [types-tqdm, types-Pillow] + additional_dependencies: [types-tqdm, types-Pillow, types-PyYAML] diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 98d2de59c..a33c9c924 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,9 +1,12 @@ +import dataclasses import inspect +import itertools import json import re import warnings -from typing import Callable, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +import yaml from jsonschema.protocols import Validator from pydantic import BaseModel, create_model from referencing import Registry, Resource @@ -20,28 +23,23 @@ NULL = r"null" WHITESPACE = r"[ ]?" -type_to_regex = { - "string": STRING, - "integer": INTEGER, - "number": NUMBER, - "boolean": BOOLEAN, - "null": NULL, -} - DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"' DATE = r'"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"' TIME = r'"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"' UUID = r'"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"' -format_to_regex = { - "uuid": UUID, - "date-time": DATE_TIME, - "date": DATE, - "time": TIME, -} + +def load_yaml(yaml_str: str) -> Any: + """Parse a YAML string and return the corresponding Python object.""" + return yaml.safe_load(yaml_str) -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): +def build_regex_from_schema( + schema: str, + whitespace_pattern: Optional[str] = None, + mode: str = "json", + strict_json_schema_subset: bool = True, +): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. @@ -60,6 +58,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + mode + Either `json` or `yaml`, determines the structure of the generated output + strict_json_schema_subset + For `items` and `properties`, the JSON Schema spec by default allows these to be unconstrained + if not set. This is usually undesired behavior, so by default strict_json_schema_subset is True. + Returns ------- @@ -83,7 +87,21 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non resolver = registry.resolver() content = schema.contents - return to_regex(resolver, content, whitespace_pattern) + + if mode == "json": + return JSONSchemaRegexGenerator( + resolver, + whitespace_pattern, + strict_json_schema_subset=strict_json_schema_subset, + ).to_regex(content) + elif mode == "yaml": + return YAMLRegexGenerator( + resolver, + whitespace_pattern, + strict_json_schema_subset=strict_json_schema_subset, + ).to_regex(content) + else: + raise ValueError(f"invalid mode: {mode}") def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: @@ -119,18 +137,6 @@ def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) - return schema_str -def _get_num_items_pattern(min_items, max_items, whitespace_pattern): - # Helper function for arrays and objects - min_items = int(min_items or 0) - if max_items is None: - return rf"{{{max(min_items - 1, 0)},}}" - else: - max_items = int(max_items) - if max_items < 1: - return None - return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" - - def validate_quantifiers( min_bound: Optional[str], max_bound: Optional[str], start_offset: int = 0 ) -> Tuple[str, str]: @@ -172,16 +178,58 @@ def validate_quantifiers( return min_bound, max_bound -def to_regex( - resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None -): +def get_schema_from_signature(fn: Callable) -> str: + """Turn a function signature into a JSON schema. + + Every JSON object valid to the output JSON Schema can be passed + to `fn` using the ** unpacking syntax. + + """ + signature = inspect.signature(fn) + arguments = {} + for name, arg in signature.parameters.items(): + if arg.annotation == inspect._empty: + raise ValueError("Each argument must have a type annotation") + else: + arguments[name] = (arg.annotation, ...) + + try: + fn_name = fn.__name__ + except Exception as e: + fn_name = "Arguments" + warnings.warn( + f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}", + category=UserWarning, + ) + model = create_model(fn_name, **arguments) + + return model.model_json_schema() + + +@dataclasses.dataclass +class Context: + """Context for json schema rule application""" + + recursion_depth: int = 0 + nesting_level: int = 0 + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def increment(self, attr: str, value: int = 1) -> "Context": + """Return a **new** Context with the specified attribute incremented by `value`""" + return dataclasses.replace(self, **{attr: getattr(self, attr) + value}) + + def __repr__(self): + return f"Context({self.__dict__})" + + +class JSONSchemaRegexGenerator: """Translate a JSON Schema instance into a regex that validates the schema. Note ---- Many features of JSON schema are missing: - - Handle `additionalProperties` keyword - - Handle types defined as a list - Handle constraints on numbers - Handle special patterns: `date`, `uri`, etc. @@ -191,362 +239,707 @@ def to_regex( ---------- resolver An object that resolves references to other instances within a schema - instance - The instance to translate whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + recursion_level + For unconstrained objects and lists ond many levels deep the pattern should be constructed. + strict_json_schema_subset + For `items` and `properties`, the JSON Schema spec by default allows these to be unconstrained + if not set. This is usually undesired behavior, so by default strict_json_schema_subset is True. """ - # set whitespace pattern - if whitespace_pattern is None: - whitespace_pattern = WHITESPACE + # Never impacted by parameters + STATIC_PRIMATIVES = {"boolean", "null"} + # Default value of primatives (when provided no parameters) + FORMAT_PRIMATIVE = { + "null": NULL, + "boolean": BOOLEAN, + "number": NUMBER, + "integer": INTEGER, + "string": STRING, + } + FORMAT_STRING = { + "uuid": UUID, + "date-time": DATE_TIME, + "date": DATE, + "time": TIME, + } + + def __init__( + self, + resolver: Resolver, + whitespace_pattern: Optional[str] = None, + max_nesting_level: int = 2, + strict_json_schema_subset: bool = True, + ): + self.resolver = resolver + self.ws = WHITESPACE if whitespace_pattern is None else whitespace_pattern + self.max_nesting_level = max_nesting_level + self.strict_json_schema_subset = strict_json_schema_subset + + def _validate_node(self, node: Any, ctx: Context): + """Validate the JSON Schema node for unsupported features and recursion limits.""" + if ctx.recursion_depth > 256: + raise NotImplementedError( + "Recursive schemas aren't currently available with Outlines." + ) - if instance == {}: - # JSON Schema Spec: Empty object means unconstrained, any json type is legal - types = [ - {"type": "boolean"}, - {"type": "null"}, - {"type": "number"}, - {"type": "integer"}, - {"type": "string"}, - {"type": "array"}, - {"type": "object"}, + if node is True: + return + + if node is False: + # this should be implemented along-side `visit_not(...)` + raise NotImplementedError("schema = False isn't available with Outlines.") + + # keys have no handling + not_implemented_keys = [ + "dependentSchemas", + "unevaluatedItems", + "unevaluatedProperties", + "contains", + "patternProperties", + "maximum", + "default", + "__proto__", ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] - regexes = [rf"({r})" for r in regexes] - return rf"{'|'.join(regexes)}" - - elif "properties" in instance: - regex = "" - regex += r"\{" - properties = instance["properties"] - required_properties = instance.get("required", []) - is_required = [item in required_properties for item in properties] - # If at least one property is required, we include the one in the lastest position - # without any comma. - # For each property before it (optional or required), we add with a comma after the property. - # For each property after it (optional), we add with a comma before the property. - if any(is_required): - last_required_pos = max([i for i, value in enumerate(is_required) if value]) - for i, (name, value) in enumerate(properties.items()): - subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) - if i < last_required_pos: - subregex = f"{subregex}{whitespace_pattern}," - elif i > last_required_pos: - subregex = f"{whitespace_pattern},{subregex}" - regex += subregex if is_required[i] else f"({subregex})?" - # If no property is required, we have to create a possible pattern for each property in which - # it's the last one necessarilly present. Then, we add the others as optional before and after - # following the same strategy as described above. - # The whole block is made optional to allow the case in which no property is returned. - else: - property_subregexes = [] - for i, (name, value) in enumerate(properties.items()): - subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) - property_subregexes.append(subregex) - possible_patterns = [] - for i in range(len(property_subregexes)): - pattern = "" - for subregex in property_subregexes[:i]: - pattern += f"({subregex}{whitespace_pattern},)?" - pattern += property_subregexes[i] - for subregex in property_subregexes[i + 1 :]: - pattern += f"({whitespace_pattern},{subregex})?" - possible_patterns.append(pattern) - regex += f"({'|'.join(possible_patterns)})?" - - regex += f"{whitespace_pattern}" + r"\}" - - return regex - - # To validate against allOf, the given data must be valid against all of the - # given subschemas. - elif "allOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"] - ] - subregexes_str = [f"{subregex}" for subregex in subregexes] - return rf"({''.join(subregexes_str)})" - - # To validate against `anyOf`, the given data must be valid against - # any (one or more) of the given subschemas. - elif "anyOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"] + # keys coinciding within same object not handled + not_implemented_key_pairs = [ + ("allOf", "anyOf"), + ("properties", "anyOf"), ] - return rf"({'|'.join(subregexes)})" - # To validate against oneOf, the given data must be valid against exactly - # one of the given subschemas. - elif "oneOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] - ] + node_invalid_keys = set(node) & set(not_implemented_keys) + if node_invalid_keys: + raise NotImplementedError( + f"Cannot handle the keys: {node_invalid_keys}. Please open an Outlines issue." + ) + for k in not_implemented_key_pairs: + if not (set(k) - set(node.keys())): + raise NotImplementedError( + f"Cannot simultaneously use the keys: {k}. Please open an Outlines issue." + ) - xor_patterns = [f"(?:{subregex})" for subregex in subregexes] + def to_regex(self, node: Any, ctx: Optional[Context] = None): + """Convert a JSON Schema node into a regular expression pattern.""" + ctx = ( + ctx.increment("recursion_depth") + if ctx + else Context(nesting_level=0, recursion_depth=0) + ) + self._validate_node(node, ctx) + + # Handle unconstrained nodes + if node in ({}, True): + return self.visit_unconstrained({}, ctx) + + # Handle multiple types (via anyOf) + if isinstance(node.get("type"), list): + subpatterns = [self.to_regex({"type": t}, ctx) for t in node["type"]] + return self.format_anyOf(subpatterns) + + # Visit based on node attributes + node_attr_to_visitor = { + "$ref": self.visit_ref, + "allOf": self.visit_allOf, + "anyOf": self.visit_anyOf, + "oneOf": self.visit_oneOf, + "enum": self.visit_enum, + "prefixItems": self.visit_array, + "items": self.visit_array, + "properties": self.visit_object, + "const": self.visit_string, + "pattern": self.visit_string, + } + for attr, visitor in node_attr_to_visitor.items(): + if attr in node: + return ( + f"({visitor(node, ctx)})" + if ctx.nesting_level > 0 + else visitor(node, ctx) + ) - return rf"({'|'.join(xor_patterns)})" + # Visit based on type + type_to_visitor = { + "number": self.visit_number, + "integer": self.visit_integer, + "string": self.visit_string, + "object": self.visit_object, + "array": self.visit_array, + } + if node.get("type") in self.STATIC_PRIMATIVES: + return self.FORMAT_PRIMATIVE[node["type"]] + if node.get("type") in type_to_visitor: + return type_to_visitor[node["type"]](node, ctx) + + return self.visit_notimplemented(node, ctx) + + ########## + # VISITORS + ########## + def visit_ref(self, node: Any, ctx: Context): + path = node["$ref"] + if path == "#": + raise NotImplementedError("Recursive schemas aren't supported") + new_node = self.resolver.lookup(path).contents + return self.to_regex(new_node, ctx) + + def visit_object(self, node: Any, ctx: Context): + """ + Handle JSON Schema `object` rules + + additionalProperties handling: + pattern for json object with values defined by instance["additionalProperties"] + enforces value type constraints recursively, "minProperties", and "maxProperties" + doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of" + + TODO: the json-schema compliant implementation is as follows: + - properties and additionalProperties can both be set simultaneously + - min/maxProperties can be specified even if properties has constraints set + """ + value_ctx = ctx.increment("nesting_level") + + # TODO: handling for node["unevaluatedProperties"] + properties = node.get("properties", not self.strict_json_schema_subset) + properties = {} if properties is True else properties + required_properties = node.get("required", []) + additional_properties = node.get("additionalProperties") + + if properties and additional_properties: + raise NotImplementedError( + "`properties` & `additionalProperties != False` not implemented. Please open an Outlines issue." + ) - # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx - elif "prefixItems" in instance: - element_patterns = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] - ] - comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" - tuple_inner = comma_split_pattern.join(element_patterns) - return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]" - - # The enum keyword is used to restrict a value to a fixed set of values. It - # must be an array with at least one element, where each element is unique. - elif "enum" in instance: - choices = [] - for choice in instance["enum"]: - if type(choice) in [int, float, bool, type(None), str]: - choices.append(re.escape(json.dumps(choice))) + elif properties and "minProperties" in node or "maxProperties" in node: + raise NotImplementedError( + "properties and minProperties / maxProperties not implemented. Please open an Outlines issue." + ) + + elif properties: + property_details = [ + { + "key_pattern": self.format_literal(name), + "value_pattern": self.to_regex(value, value_ctx), + "is_required": name in required_properties, + } + for name, value in properties.items() + ] + if any(pd["is_required"] for pd in property_details): + return self.format_object_with_required_properties( + property_details, ctx + ) else: - raise TypeError(f"Unsupported data type in enum: {type(choice)}") - return f"({'|'.join(choices)})" + return self.format_object_properties_all_optional(property_details, ctx) + + elif additional_properties is False: + return self.format_empty_object() - elif "const" in instance: - const = instance["const"] - if type(const) in [int, float, bool, type(None), str]: - const = re.escape(json.dumps(const)) else: - raise TypeError(f"Unsupported data type in const: {type(const)}") - return const - - elif "$ref" in instance: - path = f"{instance['$ref']}" - instance = resolver.lookup(path).contents - return to_regex(resolver, instance, whitespace_pattern) - - # The type keyword may either be a string or an array: - # - If it's a string, it is the name of one of the basic types. - # - If it is an array, it must be an array of strings, where each string is - # the name of one of the basic types, and each element is unique. In this - # case, the JSON snippet is valid if it matches any of the given types. - elif "type" in instance: - instance_type = instance["type"] - if instance_type == "string": - if "maxLength" in instance or "minLength" in instance: - max_items = instance.get("maxLength", "") - min_items = instance.get("minLength", "") - try: - if int(max_items) < int(min_items): - raise ValueError( - "maxLength must be greater than or equal to minLength" - ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) - except ValueError: - pass - return f'"{STRING_INNER}{{{min_items},{max_items}}}"' - elif "pattern" in instance: - pattern = instance["pattern"] - if pattern[0] == "^" and pattern[-1] == "$": - return rf'("{pattern[1:-1]}")' - else: - return rf'("{pattern}")' - elif "format" in instance: - format = instance["format"] - if format == "date-time": - return format_to_regex["date-time"] - elif format == "uuid": - return format_to_regex["uuid"] - elif format == "date": - return format_to_regex["date"] - elif format == "time": - return format_to_regex["time"] - else: - raise NotImplementedError( - f"Format {format} is not supported by Outlines" - ) + if additional_properties in (True, None): + value_pattern = self.visit_unconstrained(node, value_ctx) else: - return type_to_regex["string"] - - elif instance_type == "number": - bounds = { - "minDigitsInteger", - "maxDigitsInteger", - "minDigitsFraction", - "maxDigitsFraction", - "minDigitsExponent", - "maxDigitsExponent", - } - if bounds.intersection(set(instance.keys())): - min_digits_integer, max_digits_integer = validate_quantifiers( - instance.get("minDigitsInteger"), - instance.get("maxDigitsInteger"), - start_offset=1, - ) - min_digits_fraction, max_digits_fraction = validate_quantifiers( - instance.get("minDigitsFraction"), instance.get("maxDigitsFraction") - ) - min_digits_exponent, max_digits_exponent = validate_quantifiers( - instance.get("minDigitsExponent"), instance.get("maxDigitsExponent") - ) - integers_quantifier = ( - f"{{{min_digits_integer},{max_digits_integer}}}" - if min_digits_integer or max_digits_integer - else "*" - ) - fraction_quantifier = ( - f"{{{min_digits_fraction},{max_digits_fraction}}}" - if min_digits_fraction or max_digits_fraction - else "+" - ) - exponent_quantifier = ( - f"{{{min_digits_exponent},{max_digits_exponent}}}" - if min_digits_exponent or max_digits_exponent - else "+" - ) - return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" - return type_to_regex["number"] + # Object with arbitrary key name, constrained value + value_pattern = self.to_regex(additional_properties, value_ctx) + + return self.format_object_with_additional_properties( + value_pattern, + ctx, + min_properties=node.get("minProperties"), + max_properties=node.get("maxProperties"), + ) - elif instance_type == "integer": - if "minDigits" in instance or "maxDigits" in instance: - min_digits, max_digits = validate_quantifiers( - instance.get("minDigits"), instance.get("maxDigits"), start_offset=1 - ) - return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})" - return type_to_regex["integer"] + def visit_array(self, node: Any, ctx: Context): + """Handle JSON Schema `array` rules with optional item constraints.""" - elif instance_type == "array": - num_repeats = _get_num_items_pattern( - instance.get("minItems"), instance.get("maxItems"), whitespace_pattern + # TODO: handling for node["unevaluatedItems"] + # TODO: handling for node["additionalItems"] + # TODO: handling for node["uniqueItems"] + if "uniqueItems" in node: + raise NotImplementedError( + "uniqueItems is not implemented. Please open an Outlines issue." ) - if num_repeats is None: - return rf"\[{whitespace_pattern}\]" - allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" + elem_ctx = ctx.increment("nesting_level") + + items = node.get("items", not self.strict_json_schema_subset) - if "items" in instance: - items_regex = to_regex(resolver, instance["items"], whitespace_pattern) - return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" + if node.get("prefixItems") is not None: + # `prefixItems` determines types at each idx, which precedes `items` rules + if items in (True, None): + suffix_elem_pattern = self.visit_unconstrained(node, elem_ctx) + elif items is False: + suffix_elem_pattern = None else: - # Here we need to make the choice to exclude generating list of objects - # if the specification of the object is not given, even though a JSON - # object that contains an object here would be valid under the specification. - legal_types = [ - {"type": "boolean"}, - {"type": "null"}, - {"type": "number"}, - {"type": "integer"}, - {"type": "string"}, - ] - depth = instance.get("depth", 2) - if depth > 0: - legal_types.append({"type": "object", "depth": depth - 1}) - legal_types.append({"type": "array", "depth": depth - 1}) - - regexes = [ - to_regex(resolver, t, whitespace_pattern) for t in legal_types - ] - return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" - - elif instance_type == "object": - # pattern for json object with values defined by instance["additionalProperties"] - # enforces value type constraints recursively, "minProperties", and "maxProperties" - # doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of" - num_repeats = _get_num_items_pattern( - instance.get("minProperties"), - instance.get("maxProperties"), - whitespace_pattern, + suffix_elem_pattern = self.to_regex(items, elem_ctx) + + prefix_subpatterns = [ + self.to_regex(item, elem_ctx) for item in node["prefixItems"] + ] + return self.format_prefixItems(prefix_subpatterns, ctx, suffix_elem_pattern) + + else: + # handle simple case: no prefix items + if node.get("items") in (True, None): # noqa + items_regex = self.visit_unconstrained(node, elem_ctx) + else: + items_regex = self.to_regex(node["items"], elem_ctx) + return self.format_array( + items_regex, ctx, node.get("minItems"), node.get("maxItems") ) - if num_repeats is None: - return rf"\{{{whitespace_pattern}\}}" - - allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" - - additional_properties = instance.get("additionalProperties") - - if additional_properties is None or additional_properties is True: - # JSON Schema behavior: If the additionalProperties of an object is - # unset or True, it is unconstrained object. - # We handle this by setting additionalProperties to anyOf: {all types} - - legal_types = [ - {"type": "string"}, - {"type": "number"}, - {"type": "boolean"}, - {"type": "null"}, - ] - - # We set the object depth to 2 to keep the expression finite, but the "depth" - # key is not a true component of the JSON Schema specification. - depth = instance.get("depth", 2) - if depth > 0: - legal_types.append({"type": "object", "depth": depth - 1}) - legal_types.append({"type": "array", "depth": depth - 1}) - additional_properties = {"anyOf": legal_types} - - value_pattern = to_regex( - resolver, additional_properties, whitespace_pattern + + def visit_number(self, node: Any, ctx: Context): + quantifier_keys = [ + "minDigitsInteger", + "maxDigitsInteger", + "minDigitsFraction", + "maxDigitsFraction", + "minDigitsExponent", + "maxDigitsExponent", + ] + if any([qk in node for qk in quantifier_keys]): + min_digits_integer, max_digits_integer = validate_quantifiers( + node.get("minDigitsInteger"), + node.get("maxDigitsInteger"), + start_offset=1, ) - key_value_pattern = ( - f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" + min_digits_fraction, max_digits_fraction = validate_quantifiers( + node.get("minDigitsFraction"), node.get("maxDigitsFraction") ) - key_value_successor_pattern = ( - f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" + min_digits_exponent, max_digits_exponent = validate_quantifiers( + node.get("minDigitsExponent"), node.get("maxDigitsExponent") ) - multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}" - - return ( - r"\{" - + whitespace_pattern - + multiple_key_value_pattern - + whitespace_pattern - + r"\}" + return self.format_number_range( + min_digits_integer, + max_digits_integer, + min_digits_fraction, + max_digits_fraction, + min_digits_exponent, + max_digits_exponent, ) + else: + return self.FORMAT_PRIMATIVE["number"] - elif instance_type == "boolean": - return type_to_regex["boolean"] + def visit_integer(self, node: Any, ctx: Context): + if "maxDigits" in node or "minDigits" in node: + min_digits, max_digits = validate_quantifiers( + node.get("minDigits"), node.get("maxDigits"), start_offset=1 + ) + return self.format_integer_range(min_digits, max_digits) + else: + return self.FORMAT_PRIMATIVE["integer"] + + def visit_string(self, node: Any, ctx: Context): + if "const" in node: + return self.format_literal(node["const"]) + if "maxLength" in node or "minLength" in node: + min_length, max_length = validate_quantifiers( + node.get("minLength"), node.get("maxLength") + ) + return self.format_string_length(min_length, max_length) + elif "pattern" in node: + return self.format_string_pattern(node["pattern"]) + elif "format" in node: + return self.format_string_format(node["format"]) + return self.FORMAT_PRIMATIVE["string"] + + def visit_enum(self, node: Any, ctx: Context): + """ + The enum keyword is used to restrict a value to a fixed set of values. It + must be an array with at least one element, where each element is unique. + """ + choices = [self.format_literal(choice) for choice in node["enum"]] + return self.format_anyOf(choices) + + def visit_allOf(self, node: Any, ctx: Context): + subpatterns = [self.to_regex(subschema, ctx) for subschema in node["allOf"]] + return self.format_allOf(subpatterns) + + def visit_anyOf(self, node: Any, ctx: Context): + subpatterns = [self.to_regex(subschema, ctx) for subschema in node["anyOf"]] + return self.format_anyOf(subpatterns) + + def visit_oneOf(self, node: Any, ctx: Context): + subpatterns = [self.to_regex(subschema, ctx) for subschema in node["oneOf"]] + return self.format_oneOf(subpatterns) + + def visit_notimplemented(self, node: Any, ctx: Context): + raise NotImplementedError( + f"Handler for node `{node}` is not implemented. Please open an Outlines issue." + ) - elif instance_type == "null": - return type_to_regex["null"] + ############ + # FORMATTERS + ############ + def format_number_range( + self, + min_digits_integer, + max_digits_integer, + min_digits_fraction, + max_digits_fraction, + min_digits_exponent, + max_digits_exponent, + ): + integers_quantifier = ( + f"{{{min_digits_integer},{max_digits_integer}}}" + if min_digits_integer or max_digits_integer + else "*" + ) + fraction_quantifier = ( + f"{{{min_digits_fraction},{max_digits_fraction}}}" + if min_digits_fraction or max_digits_fraction + else "+" + ) + exponent_quantifier = ( + f"{{{min_digits_exponent},{max_digits_exponent}}}" + if min_digits_exponent or max_digits_exponent + else "+" + ) + return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" - elif isinstance(instance_type, list): - # Here we need to make the choice to exclude generating an object - # if the specification of the object is not give, even though a JSON - # object that contains an object here would be valid under the specification. - regexes = [ - to_regex(resolver, {"type": t}, whitespace_pattern) - for t in instance_type - if t != "object" - ] - return rf"({'|'.join(regexes)})" + def format_integer_range(self, min_digits=None, max_digits=None): + if min_digits or max_digits: + num_items_pattern = f"{{{min_digits},{max_digits}}}" + else: + num_items_pattern = "*" - raise NotImplementedError( - f"""Could not translate the instance {instance} to a - regular expression. Make sure it is valid to the JSON Schema specification. If - it is, please open an issue on the Outlines repository""" - ) + return rf"(-)?(0|[1-9][0-9]{num_items_pattern})" + def format_string_length(self, min_length, max_length): + return f'"{STRING_INNER}{{{min_length},{max_length}}}"' -def get_schema_from_signature(fn: Callable) -> str: - """Turn a function signature into a JSON schema. + def format_string_pattern(self, pattern: str): + if pattern[0] == "^" and pattern[-1] == "$": + pattern_string_inner = pattern[1:-1] + else: + pattern_string_inner = pattern + return f'"{pattern_string_inner}"' + + def format_string_format(self, fmt: str): + format_regex = self.FORMAT_STRING.get(fmt) + if format_regex: + return format_regex + raise NotImplementedError( + f"Format {fmt} is not supported. Please open an Outlines issue." + ) - Every JSON object valid to the output JSON Schema can be passed - to `fn` using the ** unpacking syntax. + def format_property_kv( + self, key_pattern: str, value_pattern: str, ctx: Context + ) -> str: + return f"{self.ws}{key_pattern}{self.ws}(:){self.ws}{value_pattern}" - """ - signature = inspect.signature(fn) - arguments = {} - for name, arg in signature.parameters.items(): - if arg.annotation == inspect._empty: - raise ValueError("Each argument must have a type annotation") + def format_empty_object(self): + return r"\{" + self.ws + r"\}" + + def format_object_properties_all_optional( + self, property_details: List[Dict], ctx: Context + ): + property_subregexes = [ + self.format_property_kv(pd["key_pattern"], pd["value_pattern"], ctx) + for pd in property_details + ] + possible_patterns = [ + f"{self.ws},".join(combination) + for i in range(1, len(property_subregexes) + 1) + for combination in itertools.combinations(property_subregexes, i) + ] + inner = f"({'|'.join(possible_patterns)})?" + return r"\{" + inner + self.ws + r"\}" + + def format_object_with_required_properties( + self, property_details: List[Dict], ctx: Context + ): + is_required = [prop["is_required"] for prop in property_details] + last_required_pos = max(i for i, value in enumerate(is_required) if value) + inner = "" + for i, pd in enumerate(property_details): + subregex = self.format_property_kv( + pd["key_pattern"], pd["value_pattern"], ctx + ) + if i < last_required_pos: + subregex = f"{subregex}{self.ws}," + elif i > last_required_pos: + subregex = f"{self.ws},{subregex}" + inner += subregex if is_required[i] else f"({subregex})?" + return r"\{" + inner + self.ws + r"\}" + + def format_object_with_additional_properties( + self, value_pattern: str, ctx: Context, min_properties=None, max_properties=None + ): + inner = self._regex_repeat_elem( + elem_pattern=f"({STRING}){self.ws}(:){self.ws}({value_pattern})", + separator_pattern=f"{self.ws},{self.ws}", + min_elem=min_properties, + max_elem=max_properties, + pad=self.ws, + ) + return r"\{" + inner + r"\}" + + def format_array( + self, elem_pattern: str, ctx: Context, min_items=None, max_items=None + ): + inner = self._regex_repeat_elem( + elem_pattern=elem_pattern, + separator_pattern=f"{self.ws},{self.ws}", + min_elem=min_items, + max_elem=max_items, + pad=self.ws, + ) + return rf"\[{inner}\]" + + def format_prefixItems( + self, + prefix_patterns: List[str], + ctx: Context, + suffix_elem_pattern: Optional[str] = None, + ): + comma_split_pattern = rf"{self.ws},{self.ws}" + prefix_pattern = f"{self.ws}{comma_split_pattern.join(prefix_patterns)}" + if suffix_elem_pattern: + suffix_pattern = self._regex_repeat_elem( + elem_pattern=suffix_elem_pattern, + separator_pattern=f"{self.ws},{self.ws}", + min_elem=1, + pad=self.ws, + ) + suffix_pattern = f"((,{suffix_pattern})|)" + inner = f"{prefix_pattern}{suffix_pattern}" else: - arguments[name] = (arg.annotation, ...) + inner = prefix_pattern + self.ws + return rf"\[{inner}\]" + + def format_literal(self, literal: Any): + if isinstance(literal, str): + return f"{re.escape(json.dumps(literal))}" + if type(literal) in [int, bool, type(None)]: + return re.escape(json.dumps(literal)) + elif isinstance(literal, float): + if float(literal) == int(literal): + int_literal = re.escape(json.dumps(int(literal))) + float_literal = re.escape(json.dumps(float(literal))) + return f"({int_literal}|{float_literal})" + else: + return re.escape(json.dumps(literal)) + else: + raise NotImplementedError( + f"Unsupported data type in literal: {type(literal)}. Please open an Outlines issue." + ) - try: - fn_name = fn.__name__ - except Exception as e: - fn_name = "Arguments" + def format_allOf(self, patterns: List[str]): + return ( + "(" + "".join([f"(?={pat})" for pat in patterns[:-1]]) + patterns[-1] + ")" + ) + + def format_anyOf(self, patterns: List[str]): + return "(" + "|".join([f"({pat})" for pat in patterns]) + ")" + + def format_oneOf(self, patterns: List[str]): + # If you're searching "NotImplementedError", this method also needs to be properly implemented! warnings.warn( - f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}", - category=UserWarning, + "JSON Schema `oneOf` not implemented. Using `anyOf` instead. Please open an Outlines Issue." ) - model = create_model(fn_name, **arguments) + return self.format_anyOf(patterns) - return model.model_json_schema() + def visit_unconstrained(self, node: Any, ctx: Context): + legal_types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + ] + allowed_nesting = node.get( + "_allowed_nesting", ctx.nesting_level + self.max_nesting_level + ) + # We limit the object depth to keep the expression finite, but the "depth" + # key is not a true component of the JSON Schema specification. + if ctx.nesting_level < allowed_nesting: + legal_types.append({"type": "object", "_allowed_nesting": allowed_nesting}) + legal_types.append({"type": "array", "_allowed_nesting": allowed_nesting}) + + subpatterns = [self.to_regex(t, ctx) for t in legal_types] + return self.format_anyOf(subpatterns) + + def _regex_repeat_elem( + self, + elem_pattern: str, + separator_pattern: str, + min_elem=None, + max_elem=None, + pad="", + ): + """ + Creates a pattern allowing between min_elem and max_elem occurrences of elem_pattern + Ensures each element pattern is separated by separator_pattern + Surrounds result with `pad` + """ + if str(max_elem) == "0": + return pad + + base_pattern = f"({elem_pattern})" + suffix_pattern = f"(({separator_pattern})({elem_pattern}))" + + min_suffix_repeats = "" if min_elem is None else max(0, int(min_elem) - 1) + max_suffix_repeats = "" if max_elem is None else max_elem - 1 + + if str(max_suffix_repeats) == "0": + pattern = base_pattern + else: + pattern = f"{base_pattern}({suffix_pattern}){{{min_suffix_repeats},{max_suffix_repeats}}}" + + padded_pattern = f"({pad}{pattern}{pad})" + + if not min_elem: + return f"({padded_pattern}|{pad})" + else: + return padded_pattern + + +class YAMLRegexGenerator(JSONSchemaRegexGenerator): + """ + Core differences between JSON and YAML + -------------------------------------- + + For most types including `boolean`, `null`, `number`, and `integer` + YAML supports a superset of JSON representation. For example, `boolean` can + be `true` / `false` like JSON, however it can also be `yes` / `no`. For these + types we will limit generation to the valid JSON-representation subset. + + ``` + string: + - Equivalent to JSON, but doesn't use quotes + + array: + - In YAML arrays are represented + - by newline separated + - dash-prefixed array elements + + object: + - An object is represented as a newline separated list of key: value pairs + ``` + """ + + FORMAT_PRIMATIVE = { + # yaml allows for more escape types + "string": r'([^"\\\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]|\\["\\])', + **JSONSchemaRegexGenerator.FORMAT_PRIMATIVE, + } + + @staticmethod + def _indentation(nesting_level: int): + return r"(\n)" + (f"[ ]{{{nesting_level * 2}}}" if nesting_level else "") + + def format_property_kv( + self, key_pattern: str, value_pattern: str, ctx: Context + ) -> str: + """ + Similar to JSON property kv, but with changes to accomodate yaml rules: + - leading spaces are not allowed as the spaces are syntactic + - `foo:bar` isn't a legal kv, + - need a single space, e.g. `foo: bar` + - or an indented newline, e.g. `foo:\n bar` + """ + child_indentation = self._indentation(ctx.nesting_level + 1) + return f"({key_pattern}{self.ws}(:){child_indentation}{value_pattern})" + + def format_object_properties_all_optional( + self, property_details: List[Dict], ctx: Context + ): + property_subregexes = [ + self.format_property_kv(pd["key_pattern"], pd["value_pattern"], ctx) + for pd in property_details + ] + indentation = self._indentation(ctx.nesting_level) + possible_patterns = [ + indentation.join(combination) # first indent is optional + for i in range(1, len(property_subregexes) + 1) + for combination in itertools.combinations(property_subregexes, i) + ] + one_or_more_pattern = "|".join(possible_patterns) + return f"({one_or_more_pattern}|{self.format_empty_object()})" + + def format_object_with_required_properties( + self, property_details: List[Dict], ctx: Context + ): + is_required = [prop["is_required"] for prop in property_details] + last_required_pos = max(i for i, value in enumerate(is_required) if value) + + indentation = self._indentation(ctx.nesting_level) + + inner = "" + for i, pd in enumerate(property_details): + subregex = self.format_property_kv( + pd["key_pattern"], pd["value_pattern"], ctx + ) + if i < last_required_pos: + subregex = f"{subregex}{indentation}" + elif i > last_required_pos: + subregex = f"{indentation}{subregex}" + inner += subregex if is_required[i] else f"({subregex})?" + + return inner + + def format_object_with_additional_properties( + self, value_pattern: str, ctx: Context, min_properties=None, max_properties=None + ): + if min_properties in (0, "0", "", None): + min_properties = 0 + + inner = self._regex_repeat_elem( + elem_pattern=self.format_property_kv(STRING, value_pattern, ctx), + separator_pattern=self._indentation(ctx.nesting_level), + min_elem=max(1, min_properties), + max_elem=max_properties, + ) + if min_properties == 0: + empty_obj_pattern = self.format_empty_object() + return f"({inner})|({empty_obj_pattern})" + + return inner + + def format_array( + self, elem_pattern: str, ctx: Context, min_items=None, max_items=None + ): + self_indentation = self._indentation(ctx.nesting_level) + + child_indentation = self._indentation(ctx.nesting_level + 1) + child_separator = f"([ ]|({child_indentation}))" + + if min_items in (0, "0", "", None): + min_items = 0 + + inner = self._regex_repeat_elem( + elem_pattern=f"(-){child_separator}{elem_pattern}", + separator_pattern=self_indentation, + min_elem=max(1, min_items), + max_elem=max_items, + ) + if min_items == 0: + empty_list_pattern = r"(\[\])" + return f"({inner})|({empty_list_pattern})" + return inner + + def format_prefixItems( + self, + prefix_patterns: List[str], + ctx: Context, + suffix_elem_pattern: Optional[str] = None, + ): + self_indentation = self._indentation(ctx.nesting_level) + + child_indentation = self._indentation(ctx.nesting_level + 1) + child_separator = f"([ ]|({child_indentation}))" + + prefix_pattern = self_indentation.join( + [f"(-){child_separator}{pat}" for pat in prefix_patterns] + ) + + if suffix_elem_pattern: + suffix_pattern = self._regex_repeat_elem( + elem_pattern=f"(-){child_separator}{suffix_elem_pattern}", + separator_pattern=self_indentation, + min_elem=1, + ) + suffix_pattern = f"({self_indentation}{suffix_pattern})?" + return f"{prefix_pattern}{suffix_pattern}" + else: + return prefix_pattern diff --git a/outlines/generate/json.py b/outlines/generate/json.py index f75878d29..9cd063531 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -4,7 +4,11 @@ from pydantic import BaseModel -from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature +from outlines.fsm.json_schema import ( + build_regex_from_schema, + get_schema_from_signature, + load_yaml, +) from outlines.generate.api import SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -18,6 +22,7 @@ def json( schema_object: Union[str, object, Callable], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, + mode="json", ) -> SequenceGeneratorAdapter: """ Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. @@ -36,6 +41,8 @@ def json( whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + mode + Either `json` or `yaml`, determines the structure of the generated output Returns ------- @@ -43,21 +50,26 @@ def json( transforms the result if BaseModel is used. """ + if mode == "yaml": + to_json = lambda x: pyjson.dumps(load_yaml(x)) + else: + to_json = lambda x: x + if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern, mode=mode) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: schema_object.parse_raw(x) + generator.format_sequence = lambda x: schema_object.parse_raw(to_json(x)) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern, mode=mode) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: pyjson.loads(x) + generator.format_sequence = lambda x: pyjson.loads(to_json(x)) elif isinstance(schema_object, str): schema = schema_object - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern, mode=mode) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: pyjson.loads(x) + generator.format_sequence = lambda x: pyjson.loads(to_json(x)) else: raise ValueError( f"Cannot parse schema {schema_object}. The schema must be either " diff --git a/pyproject.toml b/pyproject.toml index ab3ecd775..ed0773bc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ test = [ "torch", "transformers", "pillow", + "requests_cache", ] serve = [ "vllm>=0.3.0", @@ -136,6 +137,7 @@ module = [ "pycountry.*", "airportsdata.*", "outlines_core.*", + "requests_cache.*", ] ignore_missing_imports = true diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 7565ff642..627ef59cc 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,9 +1,11 @@ +import collections import json import re from typing import List, Literal, Union import interegular import pytest +import yaml from pydantic import BaseModel, Field, constr from outlines.fsm.json_schema import ( @@ -20,10 +22,129 @@ WHITESPACE, build_regex_from_schema, get_schema_from_signature, - to_regex, ) +def assert_patterns_equivalent( + generated_pattern, expected_pattern, n_diff=0, allow_both=False +): + gen_fsm = interegular.parse_pattern(generated_pattern).to_fsm() + expect_fsm = interegular.parse_pattern(expected_pattern).to_fsm() + if gen_fsm.reduce() != expect_fsm.reduce(): + if n_diff: + to_str = lambda s: "".join([c if isinstance(c, str) else "{*}" for c in s]) + only_generated = [ + to_str(s) + for _, s in zip(range(n_diff), gen_fsm.difference(expect_fsm).strings()) + ] + only_expected = [ + to_str(s) + for _, s in zip(range(n_diff), expect_fsm.difference(gen_fsm).strings()) + ] + additional_details = ( + f"Accepted only by generated pattern (max {n_diff}): {only_generated}\n" + f"Accepted only by expected pattern (max {n_diff}): {only_expected}\n" + ) + if allow_both: + both = [ + to_str(s) + for _, s in zip(range(n_diff), (gen_fsm & expect_fsm).strings()) + ] + additional_details += ( + f"Accepted by both patterns (max {n_diff}): {both}\n" + ) + else: + additional_details = "" + + raise ValueError( + "Patterns Not Equivalent:\n" + f"generated_pattern = {generated_pattern}\n" + f" expected_pattern = {expected_pattern}\n" + f"{additional_details}" + ) + + +def dump_yaml_normalized(data): + """ + yaml can represent the same data in many different ways. + + This function creates a normalized yaml dump which ensures + - strings are always represented with quotes + - OrderedDict is represented without !!python/object/apply:collections.OrderedDict + - End of document signifier "\n...\n" is removed + - Standardize Indentation Behavior + """ + + # handle confusion in yaml dumper + if isinstance(data, str): + return json.dumps(data) + + class NormalizedDumper(yaml.Dumper): + def increase_indent(self, flow=False, indentless=False): + return super().increase_indent(flow, False) + + def quoted_str_presenter(dumper, data): + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"') + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + # Ensure strings are always quoted + NormalizedDumper.add_representer(str, quoted_str_presenter) + # Ensure OrderedDict is represented without !!python/object/apply + NormalizedDumper.add_representer(collections.OrderedDict, dict_representer) + + dumped = yaml.dump( + data, Dumper=NormalizedDumper, default_flow_style=False, sort_keys=False + ).rstrip("\n...\n") + + # hack to normalize formatting to our yaml subset + return re.sub( + r"^([ \t-]*)([^:\s]+): (\S.*)", + lambda m: f"{m.group(1)}{m.group(2)}:\n{' ' * (len(m.group(1)) + 2)}{m.group(3)}", + dumped, + flags=re.MULTILINE, + ) + + +def assert_match_expectation(json_sample, pattern, does_match, schema, mode="json"): + """ + Ensure sample conforms to `does_match` expectation + - check sample normally if in json mode + - convert sample to normalized yaml if in yaml mode + """ + # if yaml mode, convert to yaml if possible, otherwise succeed the test + if mode == "yaml": + try: + if json.dumps(json.loads(json_sample)) != json_sample: + return + except json.decoder.JSONDecodeError: + return + + sample = dump_yaml_normalized( + json.loads(json_sample, object_pairs_hook=collections.OrderedDict) + ) + + # ensure yaml wasn't corrupted by rstrip + assert yaml.safe_load(sample) == json.loads( + json_sample + ), "invalid test, json -> yaml inconsistent" + + else: + sample = json_sample + + match = re.fullmatch(pattern, sample) + if does_match: + if not match: + raise ValueError( + f"Expected match for sample:\n{sample}\n\n" + f"Schema: {json.dumps(json.loads(schema), indent=4)}\n" + f"Generated Pattern: {repr(pattern)}\n" + ) + else: + assert match is None + + def test_function_basic(): def test_function(foo: str, bar: List[int]): pass @@ -71,7 +192,7 @@ class User(BaseModel): ) def test_match_integer(pattern, does_match): step = {"title": "Foo", "type": "integer"} - regex = to_regex(None, step) + regex = build_regex_from_schema(json.dumps(step)) assert regex == INTEGER value = pattern["integer"] @@ -98,7 +219,7 @@ def test_match_integer(pattern, does_match): ) def test_match_number(pattern, does_match): step = {"title": "Foo", "type": "number"} - regex = to_regex(None, step) + regex = build_regex_from_schema(json.dumps(step)) assert regex == NUMBER value = pattern["number"] @@ -420,7 +541,7 @@ def test_match_number(pattern, does_match): # array ( {"title": "Foo", "type": "array", "items": {"type": "number"}}, - rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]", + rf"\[(({WHITESPACE}({NUMBER})((?:{WHITESPACE},{WHITESPACE}({NUMBER}))){{,}}{WHITESPACE})|{WHITESPACE})\]", [("[1e+9,1.3]", True), ("[]", True), ("[1", False)], ), # array with a set length of 1 @@ -444,7 +565,7 @@ def test_match_number(pattern, does_match): "minItems": 3, "maxItems": 3, }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]", + rf"\[({WHITESPACE}({INTEGER})((?:{WHITESPACE},{WHITESPACE}({INTEGER}))){{2,2}}{WHITESPACE})\]", [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], ), # array with length 0 @@ -473,7 +594,7 @@ def test_match_number(pattern, does_match): }, "required": ["test_dict"], }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}(\{{({WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}({STRING})({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}({STRING})){{0,}})?{WHITESPACE}\}}){WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), ("""{ "test_dict":{"foo":"bar" }}""", True), @@ -499,7 +620,7 @@ def test_match_number(pattern, does_match): }, "required": ["test_dict"], }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}(\{{({WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{({WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{({WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}){WHITESPACE}\}}""", [ ( """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", @@ -544,14 +665,48 @@ def test_match_number(pattern, does_match): rf"({STRING}|{INTEGER})", [("12", True), ('"a"', True), ('1"a"', False)], ), + # oneOf: TODO: currently implemented as anyOf, uncomment when proper oneOf is implemented + # ( + # { + # "title": "Foo", + # "oneOf": [{"type": "string", "format": "date"}, {"type": "string", "pattern": "2024.*"}], + # }, + # rf"TODO", + # [('"2024-01-07"', False), ('"2024-01-01"', True), ('"2024foobar7"', True), ('"2024-neither"', False)], + # ), + # anyOf + ( + { + "title": "Foo", + "anyOf": [ + {"type": "string", "format": "date"}, + {"type": "string", "pattern": "2024.*7"}, + ], + }, + rf'({DATE})|("2024.*7")', + [ + ('"2024-01-07"', True), + ('"2024-01-01"', True), + ('"2024foobar7"', True), + ('"2024-neither"', False), + ], + ), # allOf ( { "title": "Foo", - "allOf": [{"type": "string"}, {"type": "integer"}], + "allOf": [ + {"type": "string", "format": "date"}, + {"type": "string", "pattern": "2024.*7"}, + ], }, - rf"({STRING}{INTEGER})", - [('"a"1', True), ('"a"', False), ('"1"', False)], + rf'(?=({DATE}))("2024.*7")', + [ + ('"2024-01-07"', True), + ('"2024-01-01"', False), + ('"2024foobar7"', False), + ('"2024-neither"', False), + ], ), # Tuple / prefixItems ( @@ -748,21 +903,24 @@ def test_match_number(pattern, does_match): ), ], ) -def test_match(schema, regex, examples): - interegular.parse_pattern(regex) +@pytest.mark.parametrize("mode", ["json", "yaml"]) +def test_match(schema, regex, examples, mode): schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - assert test_regex == regex + generated_pattern = build_regex_from_schema(schema, mode=mode) + + if mode == "json": + # patterns assert equivalence of pattern behavior to expectation + assert_patterns_equivalent( + generated_pattern=generated_pattern, expected_pattern=regex + ) + + # ensure pattern can be parsed by interegular + interegular.parse_pattern(regex) for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - if match is None: - raise ValueError(f"Expected match for '{string}'") - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None + assert_match_expectation( + string, generated_pattern, does_match, schema, mode=mode + ) @pytest.mark.parametrize( @@ -827,19 +985,17 @@ def test_match(schema, regex, examples): ), ], ) -def test_format(schema, regex, examples): +@pytest.mark.parametrize("mode", ["json", "yaml"]) +def test_format(schema, regex, examples, mode): interegular.parse_pattern(regex) schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - assert test_regex == regex + generated_pattern = build_regex_from_schema(schema, mode=mode) + assert generated_pattern == regex for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None + assert_match_expectation( + string, generated_pattern, does_match, schema, mode=mode + ) @pytest.mark.parametrize( @@ -976,16 +1132,14 @@ def test_format(schema, regex, examples): ), ], ) -def test_format_without_regex(schema, examples): +@pytest.mark.parametrize("mode", ["json", "yaml"]) +def test_format_without_regex(schema, examples, mode): schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) + generated_pattern = build_regex_from_schema(schema, mode=mode) for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None + assert_match_expectation( + string, generated_pattern, does_match, schema, mode=mode + ) @pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) @@ -1017,6 +1171,7 @@ class MockModel(BaseModel): assert re.fullmatch(pattern, mock_result_mult_ws) +@pytest.mark.skip("oneOf not implemented") def test_one_of_doesnt_produce_illegal_lookaround(): """Reproduces failure in https://github.com/dottxt-ai/outlines/issues/823""" @@ -1039,3 +1194,28 @@ class Model(BaseModel): # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() interegular.parse_pattern(pattern).to_fsm() + + +def test_all_generations_legal(): + """ + # Array of literal {"k": "v"} + ( + { + "type": "array", + "items": { + "type": "object", + "properties": { + "k": {"const": "v"} + } + }, + "required": ["k"], + "additionalProperties": False + }, + [ + ("1", True), + ] + ), + """ + # TODO: check all fsm.strings() matches the schema + # patch STRING, INTEGER, and NUMBER so they have limited length + pass diff --git a/tests/fsm/test_json_schema_full.py b/tests/fsm/test_json_schema_full.py new file mode 100644 index 000000000..eecb09d38 --- /dev/null +++ b/tests/fsm/test_json_schema_full.py @@ -0,0 +1,130 @@ +# LOCAL IMPORT HACK +import importlib +import json +import re + +import pytest +import requests +import requests_cache +import yaml +from referencing.exceptions import Unresolvable + +from outlines.fsm.json_schema import build_regex_from_schema + +dump_yaml_normalized = importlib.import_module("test_json_schema").dump_yaml_normalized + + +requests_cache.install_cache("test_request_cache", expire_after=3600) + + +def get_json_schema_tests_from_repo( + repo="json-schema-org/JSON-Schema-Test-Suite", configs_dir="tests/draft2020-12" +): + api_url = f"https://api.github.com/repos/{repo}/contents/{configs_dir}" + headers = {"Accept": "application/vnd.github.v3+json"} + response = requests.get(api_url, headers=headers) + response.raise_for_status() + contents = response.json() + + results = [] + for item in contents: + if item["type"] == "file" and item["name"].endswith(".json"): + file_url = item["download_url"] + file_response = requests.get(file_url) + file_response.raise_for_status() + json_data = file_response.json() + + for entry in json_data: + for test in entry["tests"]: + results.append( + { + "file": item["name"], + "schema": json.dumps(entry["schema"]), + "data": json.dumps(test["data"]), + "is_valid": test["valid"], + } + ) + + return results + + +@pytest.mark.skip("Utility for improving compliance with json schema spec") +@pytest.mark.parametrize("sample", get_json_schema_tests_from_repo()) +def test_json_schema_to_json_compliance(sample): + """ + Assert that we either correctly handle a schema, or skip if NotImplementedError + """ + try: + pattern = build_regex_from_schema( + sample["schema"], strict_json_schema_subset=False + ) + except NotImplementedError as e: + pytest.skip(f"{e}") + except Unresolvable: + pytest.xfail() + + if sample["is_valid"]: + assert ( + re.fullmatch(pattern, sample["data"]) is not None + ), "Failed to match valid schema" + else: + assert ( + re.fullmatch(pattern, sample["data"]) is None + ), "Incorrectly matched invalid schema" + + +@pytest.mark.parametrize("sample", get_json_schema_tests_from_repo()) +def test_json_schema_to_yaml_compliance(sample): + """ + Skip tests checking whether it can be built, that is covered by + `test_json_schema_compliance`. + + Here we are purely testing whether yaml is valid IFF json schema is valid. + """ + # skip if the test fails to construct the pattern + try: + json_pattern = build_regex_from_schema( + sample["schema"], strict_json_schema_subset=False + ) + except Exception: + pytest.skip() + + # skip invalid with json + json_valid = re.fullmatch(json_pattern, sample["data"]) is not None + if sample["is_valid"] != json_valid: + pytest.skip() + + # valide yaml generation for samples for all samples where json is valid + try: + yaml_pattern = build_regex_from_schema( + sample["schema"], mode="yaml", strict_json_schema_subset=False + ) + except NotImplementedError as e: + pytest.skip(f"{e}") + except Unresolvable: + pytest.xfail() + + yaml_sample = dump_yaml_normalized(json.loads(sample["data"])) + + # xfail complex mappings + if any(line.startswith("? ") for line in yaml_sample.split("\n")): + pytest.xfail() + + if sample["is_valid"]: + assert ( + re.fullmatch(yaml_pattern, yaml_sample) is not None + ), "Failed to match valid schema" + else: + assert ( + re.fullmatch(yaml_pattern, yaml_sample) is None + ), "Incorrectly matched invalid schema" + + +@pytest.mark.skip() +@pytest.mark.parametrize("sample", get_json_schema_tests_from_repo()) +def test_yaml_dumper_consistency(sample): + """valide output yaml is equivalent to input json""" + sample_from_json = json.loads(sample["data"]) + sample_from_yaml = yaml.safe_load(dump_yaml_normalized(sample_from_json)) + assert sample_from_yaml == sample_from_json + assert json.dumps(sample_from_yaml) == sample["data"] diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index a96ce8673..4ab30383d 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -116,12 +116,15 @@ def model_t5(tmp_path_factory): @pytest.fixture() def sample_schema(): - from pydantic import BaseModel, conint, conlist, constr + from typing import Tuple + + from pydantic import BaseModel, constr + + from outlines.types import countries class SampleSchema(BaseModel): title: constr(max_length=10) - numbers: conlist(conint(strict=True), min_length=3, max_length=3) - labels: conlist(constr(min_length=1, max_length=5), min_length=3, max_length=3) + tup: Tuple[constr(min_length=1, max_length=5), countries.Name, countries.Flag] return SampleSchema @@ -234,13 +237,13 @@ def test_generate_fsm(request, model_fixture, pattern): assert re.fullmatch(pattern, res) is not None, res -@pytest.mark.skip( - "Fix issues with JSON, some models fail this test https://github.com/dottxt-ai/outlines/issues/985" -) @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) -def test_generate_json(request, model_fixture, sample_schema): +@pytest.mark.parametrize("mode", ["json", "yaml"]) +def test_generate_json(request, model_fixture, sample_schema, mode): + if model_fixture in ("model_transformers_random", "model_bart"): + pytest.skip("model vocabulary insufficient for test") model = request.getfixturevalue(model_fixture) - generator = generate.json(model, sample_schema) + generator = generate.json(model, sample_schema, mode=mode) # asserts valid within call generator(**get_inputs(model_fixture), max_tokens=100)