diff --git a/outlines/generate/choice.py b/outlines/generate/choice.py index 72b5e3efd..92213c3df 100644 --- a/outlines/generate/choice.py +++ b/outlines/generate/choice.py @@ -1,4 +1,5 @@ import json as pyjson +import re from enum import Enum from functools import singledispatch from typing import Callable, List, Union @@ -19,6 +20,7 @@ def choice( if isinstance(choices, type(Enum)): regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices))) else: + choices = [re.escape(choice) for choice in choices] # type: ignore regex_str = r"(" + r"|".join(choices) + r")" generator = regex(model, regex_str, sampler)