diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index cddcd163f..5695dee07 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -32,10 +32,12 @@ def custom_format_fn(sequence: str) -> Any: if isinstance(python_type, EnumMeta): values = python_type.__members__.keys() - regex_str = "(" + "|".join(values) + ")" - format_fn = lambda x: str(x) + enum_regex_str: str = "(" + "|".join(values) + ")" - return regex_str, format_fn + def enum_format_fn(sequence: str) -> str: + return str(sequence) + + return enum_regex_str, enum_format_fn if python_type == float: diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 7af7f296f..266d3a68e 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,4 +1,4 @@ +from . import airports, countries from .isbn import ISBN from .phone_numbers import PhoneNumber from .zip_codes import ZipCode -from . import airports diff --git a/pyproject.toml b/pyproject.toml index c4d29142d..8a7cc1460 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ module = [ "vllm.*", "uvicorn.*", "fastapi.*", + "pycountry.*", "pyairports.*", ] ignore_missing_imports = true diff --git a/tests/test_types.py b/tests/test_types.py index 8143d3d6d..2391ccc18 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -47,10 +47,19 @@ class Model(BaseModel): [ (types.airports.IATA, "CDG", True), (types.airports.IATA, "XXX", False), + (types.countries.Alpha2, "FR", True), + (types.countries.Alpha2, "XX", False), + (types.countries.Alpha3, "UKR", True), + (types.countries.Alpha3, "XXX", False), + (types.countries.Numeric, "004", True), + (types.countries.Numeric, "900", False), + (types.countries.Name, "Ukraine", True), + (types.countries.Name, "Wonderland", False), + (types.countries.Flag, "πŸ‡ΏπŸ‡Ό", True), + (types.countries.Flag, "πŸ€—", False), ], ) def test_type_enum(custom_type, test_string, should_match): - type_name = custom_type.__name__ class Model(BaseModel):