Skip to content

Commit

Permalink
fix(args): apply default type validation when prompting
Browse files Browse the repository at this point in the history
  • Loading branch information
lt-mayonesa committed Apr 13, 2023
1 parent 9607636 commit f97911f
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 8 deletions.
2 changes: 1 addition & 1 deletion hexagon/domain/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def prompt(self, field: Union[ModelField, str], **kwargs):
)
if not model_field:
raise Exception(
f"argument field must be a field name or a ModelField instance, got {field}"
f"field [{field}] not found, must be a field name or a ModelField instance"
)

value_ = self.__prompt__.query_field(
Expand Down
49 changes: 42 additions & 7 deletions hexagon/support/prompt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from pathlib import Path
from typing import Callable, Any

from InquirerPy import inquirer
from InquirerPy.base import Choice
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator
from pydantic.fields import ModelField
from pydantic import ValidationError as PydanticValidationError
from pydantic.fields import ModelField, Validator as PydanticValidator

from hexagon.domain.args import HexagonArg
from hexagon.support.printer import log
Expand All @@ -17,14 +19,32 @@ def __init__(self, validators, cls):
self.cls = cls

def validate(self, document: Document) -> None:
# FIXME: add validation based on field type's default validation
try:
for validator in self.validators.values():
validator.func(self.cls, document.text)
except PydanticValidationError as e:
raise ValidationError(
message=" / ".join([x["msg"] for x in e.errors()]),
cursor_position=len(document.text),
)
except ValueError as e:
raise ValidationError(message=e.args[0], cursor_position=len(document.text))


def default_validator(model_field: ModelField, mapper=lambda x: x):
def func(cls, value):
value, error = model_field.sub_fields[0].validate(mapper(value), {}, loc="")
if error:
raise PydanticValidationError([error], model=cls)
return value

return func


def list_mapper(v):
return v.strip().split("\n")


class Prompt:
def query_field(self, model_field: ModelField, model_class, **kwargs):
inq = self.text
Expand All @@ -41,10 +61,7 @@ def query_field(self, model_field: ModelField, model_class, **kwargs):

type_, iterable, of_enum = field_info(model_field)

if model_field.class_validators:
args["validate"] = PromptValidator(
model_field.class_validators, model_class
)
mapper: Callable[[Any], Any] = lambda x: x

if iterable and of_enum:
args["choices"] = [
Expand All @@ -57,7 +74,8 @@ def query_field(self, model_field: ModelField, model_class, **kwargs):
]
inq = self.checkbox
elif iterable:
args["filter"] = lambda x: x.strip().split("\n")
mapper = list_mapper
args["filter"] = mapper
args["message"] = args["message"] + " (each line represents a value)"
args["multiline"] = True
inq = self.text
Expand All @@ -72,6 +90,23 @@ def query_field(self, model_field: ModelField, model_class, **kwargs):
args["message"] = args["message"] + " (relative to project root)"
inq = self.path

if model_field.sub_fields:
validators_ = {
"default": PydanticValidator(
default_validator(model_field, mapper=mapper), check_fields=True
)
}
if model_field.class_validators:
validators_.update(
{
**model_field.class_validators,
}
)
args["validate"] = PromptValidator(
validators_,
model_class,
)

args.update(**kwargs)
return inq(**args)

Expand Down
28 changes: 28 additions & 0 deletions tests_e2e/__specs/execute_tool_with_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,31 @@ def test_prompt_support_list_of_enum_arguments():
)
.exit()
)


def test_should_validate_type():
(
as_a_user(__file__)
.run_hexagon(
["prompt", "prompt_validate_type"],
os_env_vars={"HEXAGON_THEME": "default"},
)
.input("asdf")
.erase()
.input("hello world")
.erase()
.input("*()&UAS*(")
.erase()
.input("23.34")
.then_output_should_be(["total_amount: 23.34"], discard_until_first_match=True)
.then_output_should_be(
[
"To run this tool again do:",
"hexagon-test prompt prompt_validate_type --total-amount=23.34",
"or:",
"hexagon-test p prompt_validate_type -ta=23.34",
],
discard_until_first_match=True,
)
.exit()
)
4 changes: 4 additions & 0 deletions tests_e2e/execute_tool_with_args/python_module_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Args(ToolArgs):
likes: OptionalArg[list] = None
tag: OptionalArg[Category] = Category.C
available_tags: OptionalArg[List[Category]] = [Category.B, Category.E]
total_amount: OptionalArg[float] = None

@validator("age")
def validate_age(cls, arg):
Expand Down Expand Up @@ -70,3 +71,6 @@ def main(
log.result(f"tag type: {type(cli_args.tag.value).__name__}")
elif cli_args.test.value == "prompt_list_enum_choices":
log.result(f"available_tags: {cli_args.prompt('available_tags')}")
elif cli_args.test.value == "prompt_validate_type":
log.result(f"total_amount: {cli_args.prompt('total_amount')}")
log.result(f"total_amount type: {type(cli_args.total_amount.value).__name__}")

0 comments on commit f97911f

Please sign in to comment.