Skip to content

Commit

Permalink
Merge branch 'main' into setup-passed-password
Browse files Browse the repository at this point in the history
  • Loading branch information
zsimjee committed Nov 3, 2023
2 parents e3dad5f + 6267a3d commit 49ee2bd
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 20 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/examples_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,23 @@ jobs:
execute_notebooks:
runs-on: ubuntu-latest

env:
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}


steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.x
python-version: 3.11.x

- name: Install dependencies
run: |
pip install jupyter nbconvert
make full; pip install jupyter nbconvert; pip install .
- name: Execute notebooks and check for errors
run: |
Expand Down
7 changes: 1 addition & 6 deletions docs/examples/generate_structured_data_cohere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,7 @@
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/zaydsimjee/workspace/shreya-gr/guardrails/.venv/lib/python3.11/site-packages/guardrails/schema.py:228: UserWarning: Validator 1-indexed is not valid for element integer.\n",
" warnings.warn(\n",
"/Users/zaydsimjee/workspace/shreya-gr/guardrails/.venv/lib/python3.11/site-packages/guardrails/prompt/prompt.py:23: UserWarning: Prompt does not have any variables, if you are migrating follow the new variable convention documented here: https://docs.getguardrails.ai/0-2-migration/\n",
" warnings.warn(\n"
]
"text": []
}
],
"source": [
Expand Down
5 changes: 4 additions & 1 deletion docs/examples/provenance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,12 @@
"from guardrails import Guard\n",
"from guardrails.validators import ProvenanceV0\n",
"from typing import List, Union\n",
"import os\n",
"\n",
"api_key = os.environ[\"COHERE_API_KEY\"]\n",
"\n",
"# Create a cohere client\n",
"cohere_client = cohere.Client(api_key=\"<Cohere_API_KEY>\")\n",
"cohere_client = cohere.Client(api_key=api_key)\n",
"\n",
"\n",
"def embed_function(text: Union[str, List[str]]) -> np.ndarray:\n",
Expand Down
8 changes: 4 additions & 4 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def parse(
llm_output: str,
metadata: Optional[Dict] = None,
llm_api: None = None,
num_reasks: int = 1,
num_reasks: Optional[int] = None,
prompt_params: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
*args,
Expand All @@ -458,7 +458,7 @@ def parse(
llm_output: str,
metadata: Optional[Dict] = None,
llm_api: Callable[[Any], Awaitable[Any]] = ...,
num_reasks: int = 1,
num_reasks: Optional[int] = None,
prompt_params: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
*args,
Expand All @@ -472,7 +472,7 @@ def parse(
llm_output: str,
metadata: Optional[Dict] = None,
llm_api: Optional[Callable] = None,
num_reasks: int = 1,
num_reasks: Optional[int] = None,
prompt_params: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
*args,
Expand All @@ -485,7 +485,7 @@ def parse(
llm_output: str,
metadata: Optional[Dict] = None,
llm_api: Optional[Callable] = None,
num_reasks: int = 1,
num_reasks: Optional[int] = None,
prompt_params: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
*args,
Expand Down
12 changes: 8 additions & 4 deletions guardrails/validator_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union

Expand Down Expand Up @@ -245,11 +246,14 @@ def to_xml_attrib(self):
return self.rail_alias

validator_args = []
for arg in self.__init__.__code__.co_varnames[1:]:
init_args = inspect.getfullargspec(self.__init__)
for arg in init_args.args[1:]:
if arg not in ("on_fail", "args", "kwargs"):
str_arg = str(self._kwargs[arg])
str_arg = "{" + str_arg + "}" if " " in str_arg else str_arg
validator_args.append(str_arg)
arg_value = self._kwargs.get(arg)
str_arg = str(arg_value)
if str_arg is not None:
str_arg = "{" + str_arg + "}" if " " in str_arg else str_arg
validator_args.append(str_arg)

params = " ".join(validator_args)
return f"{self.rail_alias}: {params}"
Expand Down
5 changes: 3 additions & 2 deletions guardrails/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import rstr
from tenacity import retry, stop_after_attempt, wait_random_exponential

from guardrails.utils.casting_utils import to_int
from guardrails.utils.docs_utils import get_chunks_from_text, sentence_split
from guardrails.utils.sql_utils import SQLDriver, create_sql_driver
from guardrails.utils.validator_utils import PROVENANCE_V1_PROMPT
Expand Down Expand Up @@ -277,8 +278,8 @@ def __init__(
on_fail: Optional[Callable] = None,
):
super().__init__(on_fail=on_fail, min=min, max=max)
self._min = int(min) if min is not None else None
self._max = int(max) if max is not None else None
self._min = to_int(min)
self._max = to_int(max)

def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
"""Validates that the length of value is within the expected range."""
Expand Down
2 changes: 1 addition & 1 deletion guardrails/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.5"
__version__ = "0.2.6"

0 comments on commit 49ee2bd

Please sign in to comment.