Skip to content

Commit

Permalink
Merge pull request #517 from guardrails-ai/skeleton-reask-engineering
Browse files Browse the repository at this point in the history
Skeleton reask prompt engineering
  • Loading branch information
zsimjee authored Dec 19, 2023
2 parents 11442aa + b41de2e commit 991a8f2
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 5 deletions.
6 changes: 6 additions & 0 deletions guardrails/constants.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ ${output_schema}
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
</json_suffix_without_examples>

<json_suffix_with_structure_example>
${gr.json_suffix_without_examples}
Here's an example of the structure:
${json_example}
</json_suffix_with_structure_example>

<complete_json_suffix>
Given below is XML that describes the information to extract from this document and the tags to extract it into.

Expand Down
56 changes: 56 additions & 0 deletions guardrails/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def __init__(
self.description = description
self.optional = optional

def get_example(self):
raise NotImplementedError

@property
def validators(self) -> TypedList:
return self.validators_attr.validators
Expand Down Expand Up @@ -188,6 +191,9 @@ class String(ScalarType):

tag = "string"

def get_example(self):
return "string"

def from_str(self, s: str) -> Optional[str]:
"""Create a String from a string."""
return to_string(s)
Expand All @@ -214,6 +220,9 @@ class Integer(ScalarType):

tag = "integer"

def get_example(self):
return 1

def from_str(self, s: str) -> Optional[int]:
"""Create an Integer from a string."""
return to_int(s)
Expand All @@ -225,6 +234,9 @@ class Float(ScalarType):

tag = "float"

def get_example(self):
return 1.5

def from_str(self, s: str) -> Optional[float]:
"""Create a Float from a string."""
return to_float(s)
Expand All @@ -236,6 +248,9 @@ class Boolean(ScalarType):

tag = "bool"

def get_example(self):
return True

def from_str(self, s: Union[str, bool]) -> Optional[bool]:
"""Create a Boolean from a string."""
if s is None:
Expand Down Expand Up @@ -273,6 +288,9 @@ def __init__(
super().__init__(children, validators_attr, optional, name, description)
self.date_format = None

def get_example(self):
return datetime.date.today()

def from_str(self, s: str) -> Optional[datetime.date]:
"""Create a Date from a string."""
if s is None:
Expand Down Expand Up @@ -312,6 +330,9 @@ def __init__(
self.time_format = "%H:%M:%S"
super().__init__(children, validators_attr, optional, name, description)

def get_example(self):
return datetime.time()

def from_str(self, s: str) -> Optional[datetime.time]:
"""Create a Time from a string."""
if s is None:
Expand Down Expand Up @@ -340,6 +361,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))

def get_example(self):
return "[email protected]"


@deprecate_type
@register_type("url")
Expand All @@ -352,6 +376,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))

def get_example(self):
return "https://example.com"


@deprecate_type
@register_type("pythoncode")
Expand All @@ -364,6 +391,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))

def get_example(self):
return "print('hello world')"


@deprecate_type
@register_type("sql")
Expand All @@ -376,13 +406,19 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))

def get_example(self):
return "SELECT * FROM table"


@register_type("percentage")
class Percentage(ScalarType):
"""Element tag: `<percentage>`"""

tag = "percentage"

def get_example(self):
return "20%"


@register_type("enum")
class Enum(ScalarType):
Expand All @@ -402,6 +438,9 @@ def __init__(
super().__init__(children, validators_attr, optional, name, description)
self.enum_values = enum_values

def get_example(self):
return self.enum_values[0]

def from_str(self, s: str) -> Optional[str]:
"""Create an Enum from a string."""
if s is None:
Expand Down Expand Up @@ -434,6 +473,9 @@ class List(NonScalarType):

tag = "list"

def get_example(self):
return [e.get_example() for e in self._children.values()]

def collect_validation(
self,
key: str,
Expand Down Expand Up @@ -476,6 +518,9 @@ class Object(NonScalarType):

tag = "object"

def get_example(self):
return {k: v.get_example() for k, v in self._children.items()}

def collect_validation(
self,
key: str,
Expand Down Expand Up @@ -546,6 +591,14 @@ def __init__(
super().__init__(children, validators_attr, optional, name, description)
self.discriminator_key = discriminator_key

def get_example(self):
first_discriminator = list(self._children.keys())[0]
first_child = list(self._children.values())[0]
return {
self.discriminator_key: first_discriminator,
**first_child.get_example(),
}

@classmethod
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
# grab `discriminator` attribute
Expand Down Expand Up @@ -606,6 +659,9 @@ def __init__(
) -> None:
super().__init__(children, validators_attr, optional, name, description)

def get_example(self):
return {k: v.get_example() for k, v in self._children.items()}

def collect_validation(
self,
key: str,
Expand Down
9 changes: 7 additions & 2 deletions guardrails/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def check_valid_reask_prompt(self, reask_prompt: Optional[str]) -> None:


class JsonSchema(Schema):
reask_prompt_vars = {"previous_response", "output_schema"}
reask_prompt_vars = {"previous_response", "output_schema", "json_example"}

def __init__(
self,
Expand Down Expand Up @@ -269,7 +269,7 @@ def get_reask_setup(
if reask_prompt_template is None:
reask_prompt_template = Prompt(
constants["high_level_skeleton_reask_prompt"]
+ constants["json_suffix_without_examples"]
+ constants["json_suffix_with_structure_example"]
)

# This is incorrect
Expand Down Expand Up @@ -300,6 +300,10 @@ def get_reask_setup(
)

pruned_tree_string = pruned_tree_schema.transpile()
json_example = json.dumps(
pruned_tree_schema.root_datatype.get_example(),
indent=2,
)

def reask_decoder(obj):
decoded = {}
Expand All @@ -317,6 +321,7 @@ def reask_decoder(obj):
reask_value, indent=2, default=reask_decoder, ensure_ascii=False
),
output_schema=pruned_tree_string,
json_example=json_example,
**(prompt_params or {}),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ I was given the following JSON response, which had problems due to incorrect val

Help me correct the incorrect values based on the given error messages.


Given below is XML that describes the information to extract from this document and the tags to extract it into.

<output>
Expand All @@ -85,6 +86,18 @@ Given below is XML that describes the information to extract from this document

ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.

Here's an example of the structure:
{
"fees": [
{
"name": "string",
"explanation": "string",
"value": 1.5
}
],
"interest_rates": {}
}


Json Output:

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ I was given the following JSON response, which had problems due to incorrect val

Help me correct the incorrect values based on the given error messages.


Given below is XML that describes the information to extract from this document and the tags to extract it into.

<output>
Expand All @@ -23,3 +24,10 @@ Given below is XML that describes the information to extract from this document


ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.

Here's an example of the structure:
{
"name": "string",
"director": "string",
"release_year": 1
}
11 changes: 8 additions & 3 deletions tests/unit_tests/utils/test_reask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest
from lxml import etree as ET

from guardrails import Instructions, Prompt
from guardrails.classes.history.iteration import Iteration
from guardrails.datatypes import Object
from guardrails.schema import JsonSchema
Expand Down Expand Up @@ -443,10 +442,14 @@ def test_get_reask_prompt(
Help me correct the incorrect values based on the given error messages.
Given below is XML that describes the information to extract from this document and the tags to extract it into.
%s
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
Here's an example of the structure:
%s
""" # noqa: E501
expected_instructions = """
You are a helpful assistant only capable of communicating with valid JSON, and no other text.
Expand All @@ -467,13 +470,15 @@ def test_get_reask_prompt(
result_prompt,
instructions,
) = output_schema.get_reask_setup(reasks, reask_json, False)
json_example = output_schema.root_datatype.get_example()

assert result_prompt == Prompt(
assert result_prompt.source == (
expected_result_template
% (
json.dumps(reask_json, indent=2),
expected_rail,
json.dumps(json_example, indent=2),
)
)

assert instructions == Instructions(expected_instructions)
assert instructions.source == expected_instructions

0 comments on commit 991a8f2

Please sign in to comment.