diff --git a/README.md b/README.md index f14c1e925..e9285e863 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,33 @@ print(add(**result)) A great advantage of passing functions directly to specify the structure is that the structure of the LLM will change with the function's definition. No need to change the code at several places! +You can also embed various functions into an enum to generate params: + +```python +from enum import Enum +from functools import partial + +import outlines + + +def add(a: int, b: int) -> int: + return a + b + +def mul(c: float, d: float) -> float: + return c * d + +class Operation(Enum): + add = partial(add) + mul = partial(mul) + +model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1") +generator = outlines.generate.json(model, add) +result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.") + +print(result) +# {'c': -3.14, 'd': 1.5} +``` + ## Prompting Building prompts can get messy. **Outlines** makes it easier to write and manage diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 98d2de59c..0bab57923 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -2,6 +2,7 @@ import json import re import warnings +from enum import Enum from typing import Callable, Optional, Tuple, Type, Union from jsonschema.protocols import Validator @@ -306,6 +307,8 @@ def to_regex( for choice in instance["enum"]: if type(choice) in [int, float, bool, type(None), str]: choices.append(re.escape(json.dumps(choice))) + elif isinstance(choice, dict): + choices.append(to_regex(resolver, choice, whitespace_pattern)) else: raise TypeError(f"Unsupported data type in enum: {type(choice)}") return f"({'|'.join(choices)})" @@ -524,7 +527,7 @@ def to_regex( ) -def get_schema_from_signature(fn: Callable) -> str: +def get_schema_from_signature(fn: Callable) -> dict: """Turn a function signature into a JSON schema. Every JSON object valid to the output JSON Schema can be passed @@ -550,3 +553,16 @@ def get_schema_from_signature(fn: Callable) -> str: model = create_model(fn_name, **arguments) return model.model_json_schema() + + +def get_schema_from_enum(myenum: type[Enum]) -> dict: + if len(myenum) == 0: + raise ValueError( + f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)" + ) + choices = [ + get_schema_from_signature(elt.value.func) if callable(elt.value) else elt.value + for elt in myenum + ] + schema = {"title": myenum.__name__, "enum": choices} + return schema diff --git a/outlines/generate/json.py b/outlines/generate/json.py index f75878d29..703447958 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -1,10 +1,15 @@ import json as pyjson +from enum import Enum from functools import singledispatch from typing import Callable, Optional, Union from pydantic import BaseModel -from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature +from outlines.fsm.json_schema import ( + build_regex_from_schema, + get_schema_from_enum, + get_schema_from_signature, +) from outlines.generate.api import SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -48,6 +53,11 @@ def json( regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: schema_object.parse_raw(x) + elif isinstance(schema_object, type(Enum)): + schema = pyjson.dumps(get_schema_from_enum(schema_object)) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 7565ff642..6f0b59c50 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,5 +1,8 @@ import json import re +from contextlib import nullcontext +from enum import Enum +from functools import partial from typing import List, Literal, Union import interegular @@ -19,6 +22,7 @@ UUID, WHITESPACE, build_regex_from_schema, + get_schema_from_enum, get_schema_from_signature, to_regex, ) @@ -237,8 +241,26 @@ def test_match_number(pattern, does_match): ), # Enum mix of types ( - {"title": "Foo", "enum": [6, 5.3, "potato", True, None]}, - r'(6|5\.3|"potato"|true|null)', + { + "title": "Foo", + "enum": [ + 6, + 5.3, + "potato", + True, + None, + { + "properties": { + "a": {"title": "A", "type": "number"}, + "b": {"title": "B", "type": "number"}, + }, + "required": ["a", "b"], + "title": "add", + "type": "object", + }, + ], + }, + r'(6|5\.3|"potato"|true|null|\{[ ]?"a"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?,[ ]?"b"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\})', [ ("6", True), ("5.3", True), @@ -248,6 +270,8 @@ def test_match_number(pattern, does_match): ("523", False), ("True", False), ("None", False), + ('{"a": -1.0, "b": 1.1}', True), + ('{"a": "a", "b": 1.1}', False), ], ), # integer @@ -1039,3 +1063,34 @@ class Model(BaseModel): # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() interegular.parse_pattern(pattern).to_fsm() + + +def add(a: float, b: float) -> float: + return a + b + + +class MyEnum(Enum): + add = partial(add) + a = "a" + b = 2 + + +# if you don't register your function as callable, you will get an empty enum +class EmptyEnum(Enum): + add = add + + +@pytest.mark.parametrize( + "enum,expectation", + [ + (MyEnum, nullcontext()), + (EmptyEnum, pytest.raises(ValueError)), + ], +) +def test_enum_schema(enum, expectation): + with expectation: + result = get_schema_from_enum(enum) + assert result["title"] == enum.__name__ + assert len(result["enum"]) == len(enum) + for elt in result["enum"]: + assert type(elt) in [int, float, bool, type(None), str, dict] diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index 8acb87500..92c5d789c 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -1,6 +1,7 @@ import datetime import re from enum import Enum +from functools import partial from typing import List, Union import pytest @@ -354,6 +355,29 @@ class User(BaseModel): assert result.user_id in [1, 2] +def add(a: int, b: int) -> int: + return a + b + + +def mul(c: float, d: float) -> float: + return c * d + + +def test_transformers_json_function_enum(model): + prompt = "Output some JSON " + + class Operation(Enum): + add = partial(add) + mul = partial(mul) + + result = generate.json(model, Operation)(prompt, seed=0) + assert isinstance(result, dict) + assert len(result) == 2 + for k, v in result.items(): + assert k in ["a", "b", "c", "d"] + assert isinstance(v, (int, float)) + + def test_transformers_json_array(model): prompt = "Output some JSON "