Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add json call with multi-function enums #1277

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,33 @@ print(add(**result))

A great advantage of passing functions directly to specify the structure is that the structure of the LLM will change with the function's definition. No need to change the code at several places!

You can also embed various functions into an enum to generate params:

```python
from enum import Enum
from functools import partial

import outlines


def add(a: int, b: int) -> int:
return a + b

def mul(c: float, d: float) -> float:
return c * d

class Operation(Enum):
add = partial(add)
mul = partial(mul)

model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
generator = outlines.generate.json(model, add)
result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.")

print(result)
# {'c': -3.14, 'd': 1.5}
```

## Prompting

Building prompts can get messy. **Outlines** makes it easier to write and manage
Expand Down
18 changes: 17 additions & 1 deletion outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import re
import warnings
from enum import Enum
from typing import Callable, Optional, Tuple, Type, Union

from jsonschema.protocols import Validator
Expand Down Expand Up @@ -306,6 +307,8 @@ def to_regex(
for choice in instance["enum"]:
if type(choice) in [int, float, bool, type(None), str]:
choices.append(re.escape(json.dumps(choice)))
elif isinstance(choice, dict):
choices.append(to_regex(resolver, choice, whitespace_pattern))
else:
raise TypeError(f"Unsupported data type in enum: {type(choice)}")
return f"({'|'.join(choices)})"
Expand Down Expand Up @@ -524,7 +527,7 @@ def to_regex(
)


def get_schema_from_signature(fn: Callable) -> str:
def get_schema_from_signature(fn: Callable) -> dict:
"""Turn a function signature into a JSON schema.

Every JSON object valid to the output JSON Schema can be passed
Expand All @@ -550,3 +553,16 @@ def get_schema_from_signature(fn: Callable) -> str:
model = create_model(fn_name, **arguments)

return model.model_json_schema()


def get_schema_from_enum(myenum: type[Enum]) -> dict:
if len(myenum) == 0:
raise ValueError(
f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
)
choices = [
get_schema_from_signature(elt.value.func) if callable(elt.value) else elt.value
for elt in myenum
]
schema = {"title": myenum.__name__, "enum": choices}
return schema
12 changes: 11 additions & 1 deletion outlines/generate/json.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import json as pyjson
from enum import Enum
from functools import singledispatch
from typing import Callable, Optional, Union

from pydantic import BaseModel

from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from outlines.fsm.json_schema import (
build_regex_from_schema,
get_schema_from_enum,
get_schema_from_signature,
)
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial
Expand Down Expand Up @@ -48,6 +53,11 @@ def json(
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
elif isinstance(schema_object, type(Enum)):
schema = pyjson.dumps(get_schema_from_enum(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
Expand Down
59 changes: 57 additions & 2 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import json
import re
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import List, Literal, Union

import interegular
Expand All @@ -19,6 +22,7 @@
UUID,
WHITESPACE,
build_regex_from_schema,
get_schema_from_enum,
get_schema_from_signature,
to_regex,
)
Expand Down Expand Up @@ -237,8 +241,26 @@ def test_match_number(pattern, does_match):
),
# Enum mix of types
(
{"title": "Foo", "enum": [6, 5.3, "potato", True, None]},
r'(6|5\.3|"potato"|true|null)',
{
"title": "Foo",
"enum": [
6,
5.3,
"potato",
True,
None,
{
"properties": {
"a": {"title": "A", "type": "number"},
"b": {"title": "B", "type": "number"},
},
"required": ["a", "b"],
"title": "add",
"type": "object",
},
],
},
r'(6|5\.3|"potato"|true|null|\{[ ]?"a"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?,[ ]?"b"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\})',
[
("6", True),
("5.3", True),
Expand All @@ -248,6 +270,8 @@ def test_match_number(pattern, does_match):
("523", False),
("True", False),
("None", False),
('{"a": -1.0, "b": 1.1}', True),
('{"a": "a", "b": 1.1}', False),
],
),
# integer
Expand Down Expand Up @@ -1039,3 +1063,34 @@ class Model(BaseModel):

# check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm()
interegular.parse_pattern(pattern).to_fsm()


def add(a: float, b: float) -> float:
return a + b


class MyEnum(Enum):
add = partial(add)
a = "a"
b = 2


# if you don't register your function as callable, you will get an empty enum
class EmptyEnum(Enum):
add = add


@pytest.mark.parametrize(
"enum,expectation",
[
(MyEnum, nullcontext()),
(EmptyEnum, pytest.raises(ValueError)),
],
)
def test_enum_schema(enum, expectation):
with expectation:
result = get_schema_from_enum(enum)
assert result["title"] == enum.__name__
assert len(result["enum"]) == len(enum)
for elt in result["enum"]:
assert type(elt) in [int, float, bool, type(None), str, dict]
24 changes: 24 additions & 0 deletions tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import re
from enum import Enum
from functools import partial
from typing import List, Union

import pytest
Expand Down Expand Up @@ -354,6 +355,29 @@ class User(BaseModel):
assert result.user_id in [1, 2]


def add(a: int, b: int) -> int:
return a + b


def mul(c: float, d: float) -> float:
return c * d


def test_transformers_json_function_enum(model):
prompt = "Output some JSON "

class Operation(Enum):
add = partial(add)
mul = partial(mul)

result = generate.json(model, Operation)(prompt, seed=0)
assert isinstance(result, dict)
assert len(result) == 2
for k, v in result.items():
assert k in ["a", "b", "c", "d"]
assert isinstance(v, (int, float))


def test_transformers_json_array(model):
prompt = "Output some JSON "

Expand Down
Loading