Skip to content

Commit

Permalink
Add to_regex method to the different types
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 29, 2024
1 parent 8c2e1fa commit c2fb5cb
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
34 changes: 34 additions & 0 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pydantic import BaseModel, TypeAdapter
from typing_extensions import _TypedDictMeta # type: ignore

from outlines.fsm.json_schema import build_regex_from_schema

from . import airports, countries
from .email import Email
from .isbn import ISBN
Expand All @@ -30,6 +32,7 @@ class Json:
"""

definition: Union[str, dict]
whitespace_pattern: str = " "

def to_json_schema(self):
if isinstance(self.definition, str):
Expand All @@ -52,11 +55,21 @@ def to_json_schema(self):

return schema

def to_regex(self):
    """Build a regular expression from this JSON definition.

    The schema obtained from ``to_json_schema`` is serialized to a JSON
    string and handed, together with ``self.whitespace_pattern``, to
    ``build_regex_from_schema``; that function's result is returned as-is.
    """
    serialized_schema = json.dumps(self.to_json_schema())
    return build_regex_from_schema(serialized_schema, self.whitespace_pattern)


@dataclass
class List:
    """Wrapper around a list definition.

    Conversion to a regular expression is not supported for this type;
    ``to_regex`` always raises.
    """

    definition: list

    def to_regex(self):
        """Always raise ``NotImplementedError``.

        Raises
        ------
        NotImplementedError
            Unconditionally, since list types have no regex conversion.
        """
        message = "Structured generation for lists of objects are not implemented yet."
        raise NotImplementedError(message)


@dataclass
class Choice:
Expand All @@ -67,3 +80,24 @@ class Choice:
def __post_init__(self):
if isinstance(self.definition, list):
self.definition = Enum("Definition", [(x, x) for x in self.definition])

def to_list(self):
if isinstance(self.definition, list):
return self.definition
else:
return [x.value for x in self.definition]

def to_regex(self):
choices = self.to_list()
regex_str = r"(" + r"|".join(choices) + r")"
return regex_str


@dataclass
class Regex:
    """A string type whose definition is a regular expression."""

    definition: str

    def to_regex(self):
        """Return the stored definition unchanged."""
        pattern = self.definition
        return pattern
12 changes: 12 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ def test_type_choice():
choice_type = types.Choice(choices)
assert choice_type.definition.a.value == "a"

regex_str = choice_type.to_regex()
assert regex_str == "(a|b)"


def test_type_list():
    """``types.List.to_regex`` must raise ``NotImplementedError``."""

    class Foo(BaseModel):
        bar: int

    unsupported = types.List(Foo)
    with pytest.raises(NotImplementedError, match="Structured"):
        unsupported.to_regex()


@pytest.mark.parametrize(
"custom_type,test_string,should_match",
Expand Down

0 comments on commit c2fb5cb

Please sign in to comment.