Skip to content

Commit

Permalink
Remove match case syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 24, 2023
1 parent 52b105e commit c5e97a0
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 26 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ jobs:
tests:
name: Run the tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.10"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: ${{ matrix.python-version }}
- name: Set up test environment
run: |
python -m pip install --upgrade pip
Expand Down
49 changes: 30 additions & 19 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,39 +202,50 @@ def match_step_to_regex(step):
schedule's step.
"""
match step:
case str() as step:
return step
if isinstance(step, str):
return step

case {"enum": choices, "type": "string"}:
choices = [f'"{re.escape(choice)}"' for choice in choices]
if isinstance(step, dict):
keys = set(step.keys())

if all(key in keys for key in ("enum", "type")) and step["type"] == "string":
choices = [f'"{re.escape(choice)}"' for choice in step["enum"]]
return f"({'|'.join(choices)})"
case {"enum": choices}:
choices = [re.escape(str(choice)) for choice in choices]

elif "enum" in keys:
choices = [re.escape(str(choice)) for choice in step["enum"]]
return f"({'|'.join(choices)})"

case {"type": "array", "items": items}:
item_regexes = match_step_to_regex(items)
elif all(key in keys for key in ("type", "items")) and step["type"] == "array":
item_regexes = match_step_to_regex(step["items"])
return rf"\[({item_regexes})(,({item_regexes}))*\]"

case {"type": "object"} as object:
steps = build_schedule_from_schema(json.dumps(object))
elif "type" in keys and step["type"] == "object":
steps = build_schedule_from_schema(json.dumps(step))
regex_str = ""
for step in steps:
regex_str += match_step_to_regex(step)
return regex_str

case {"type": "string", "maxLength": max_length}:
elif (
all(key in keys for key in ("type", "maxLength"))
and step["type"] == "string"
):
max_length = step["maxLength"]
return f'"{STRING_INNER}{{,{max_length}}}"'
case {"type": "string", "minLength": min_length}:

elif (
all(key in keys for key in ("type", "minLength"))
and step["type"] == "string"
):
min_length = step["minLength"]
return f'"{STRING_INNER}{{{min_length},}}"'

case {"type": field_type}:
return type_to_regex[field_type]
elif "type" in keys:
return type_to_regex[step["type"]]

case {"anyOf": choices}:
regexes = [match_step_to_regex(choice) for choice in choices]
elif "anyOf" in keys:
regexes = [match_step_to_regex(choice) for choice in step["anyOf"]]
return rf"({'|'.join(regexes)})"

case _:
raise NotImplementedError
raise NotImplementedError
4 changes: 2 additions & 2 deletions outlines/text/parsing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import ChainMap
from copy import copy, deepcopy
from dataclasses import dataclass
from functools import cache
from functools import lru_cache
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -576,7 +576,7 @@ def parse_from_state(self, state, last_token=None, is_end=False):

class PartialScanner(Scanner):
@classmethod
@cache
@lru_cache
def construct_terminal_fsm(cls, terminal):
# TODO: This should really be done at the lexer/parser level so that
# the lifetime of these objects is tied to the parser itself.
Expand Down
4 changes: 2 additions & 2 deletions outlines/text/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import textwrap
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, cast
from typing import Any, Callable, Dict, List, Optional, Type, cast

from jinja2 import Environment, StrictUndefined
from pydantic import BaseModel
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_schema_dict(model: Dict):


@get_schema.register(type(BaseModel))
def get_schema_pydantic(model: type[BaseModel]):
def get_schema_pydantic(model: Type[BaseModel]):
"""Return the schema of a Pydantic model."""
if not type(model) == type(BaseModel):
raise TypeError("The `schema` filter only applies to Pydantic models.")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "outlines"
authors= [{name = "Outlines Developers"}]
description = "Probabilistic Generative Model Programming"
requires-python = ">=3.10"
requires-python = ">=3.8"
keywords=[
"machine learning",
"deep learning",
Expand Down

0 comments on commit c5e97a0

Please sign in to comment.