diff --git a/hamilton/function_modifiers/metadata.py b/hamilton/function_modifiers/metadata.py index afdff918..472a00a9 100644 --- a/hamilton/function_modifiers/metadata.py +++ b/hamilton/function_modifiers/metadata.py @@ -1,4 +1,7 @@ -from typing import Any, Callable, Dict +import typing +from typing import Any, Callable, Dict, Type, TypedDict + +import typing_inspect from hamilton import node from hamilton.function_modifiers import base @@ -45,13 +48,17 @@ class tag(base.NodeDecorator): "module", ] # Anything that starts with any of these is banned, the framework reserves the right to manage it - def __init__(self, **tags: str): + def __init__(self, *, __validate_tag_types: bool = True, **tags: str): """Constructor for adding tag annotations to a function. :param tags: the keys are always going to be strings, so the type annotation here means the values are strings. Implicitly this is `Dict[str, str]` but the PEP guideline is to only annotate it with `str`. + :param __validate_tag_types: If true, we validate the types of the tags. This is called by the framework, and + should not be called by users. If you want to have more than just str valued tags, consider using typed tags + as specified below. """ self.tags = tags + self.__validate_tag_types = __validate_tag_types def decorate_node(self, node_: node.Node) -> node.Node: """Decorates the nodes produced by this with the specified tags @@ -105,8 +112,11 @@ def validate(self, fn: Callable): """ bad_tags = set() for key, value in self.tags.items(): - if (not tag._key_allowed(key)) or (not tag._value_allowed(value)): + if not tag._key_allowed(key): + bad_tags.add((key, value)) + if not tag._value_allowed(value) and not self.__validate_tag_types: bad_tags.add((key, value)) + if bad_tags: bad_tags_formatted = ",".join([f"{key}={value}" for key, value in bad_tags]) raise base.InvalidDecoratorException( @@ -132,3 +142,70 @@ def decorate_node(self, node_: node.Node) -> node.Node: new_tags = node_.tags.copy() new_tags.update(self.tag_mapping.get(node_.name, {})) return tag(**new_tags).decorate_node(node_) + + +# class TypedTagSet(TypedDict): +# """A typed tag set is a dictionary of tags that are typed. We do additional validation on this +# to ensure that the right types are created and that that ri""" + + +def _type_allowed(type: Type[Type], allow_lists: bool = True) -> bool: + """Validates that a type is allowed. We only allow primitive types and lists of primitive types""" + if type in [int, float, str, bool]: + return True + if allow_lists: + if typing_inspect.is_generic_type(type): + if typing_inspect.get_origin(type) == list: + return _type_allowed(typing_inspect.get_args(type)[0], allow_lists=False) + return False + + +def _validate_spec(typed_dict_class: Type[TypedDict]): + invalid_types = [] + for key, value in typing.get_type_hints(typed_dict_class).items(): + if not _type_allowed(value, allow_lists=True): + invalid_types.append((key, value)) + if invalid_types: + invalid_types_formatted = ",".join([f"{key}={value}" for key, value in invalid_types]) + raise base.InvalidDecoratorException( + f"The following key/value pairs are invalid as types: {invalid_types_formatted} " + "Types can be any primitive type or a list of a primitive type." + ) + + +def _type_matches(value: Any, type_: Type[Type]): + if type_ in [int, float, str, bool]: + return isinstance(value, type_) + if typing_inspect.is_generic_type(type_): + if typing_inspect.get_origin(type_) == list: + return isinstance(value, list) and all( + _type_matches(item, typing_inspect.get_args(type_)[0]) for item in value + ) + return False + + +def _validate_values(typed_dict: dict, typed_dict_class: Type[TypedDict]): + invalid_pairs = [] + for key, value in typed_dict.items(): + if not _type_matches(value, typing.get_type_hints(typed_dict_class)[key]): + invalid_pairs.append((key, value)) + if invalid_pairs: + invalid_pairs_formatted = ",".join([f"{key}={value}" for key, value in invalid_pairs]) + raise base.InvalidDecoratorException( + f"The following key/value pairs are invalid as values: {invalid_pairs_formatted} " + "Values must match the specified type." + ) + + +def validate_typed_dict(data: dict, typed_dict_class: TypedDict): + _validate_spec(typed_dict_class) + _validate_values(data, typed_dict_class) + + +class typed_tags: + def __init__(self, typed_tag_class: TypedDict): + self.tag_set_type = typed_tag_class + + def __call__(self, **kwargs: Any): + validate_typed_dict(dict(**kwargs), self.tag_set_type) + return tag(**kwargs, __validate_tag_types=False) # types are already validated diff --git a/tests/function_modifiers/test_metadata.py b/tests/function_modifiers/test_metadata.py index 7072210f..c0a1f38b 100644 --- a/tests/function_modifiers/test_metadata.py +++ b/tests/function_modifiers/test_metadata.py @@ -1,7 +1,11 @@ +from typing import Dict, List, TypedDict + import pandas as pd import pytest from hamilton import function_modifiers, node +from hamilton.function_modifiers.base import InvalidDecoratorException +from hamilton.function_modifiers.metadata import typed_tags def test_tags(): @@ -108,3 +112,82 @@ def dummy_tagged_function() -> pd.DataFrame: assert node_map["b"].tags["tag_b_gets"] == "tag_value_b_gets" assert node_map["a"].tags["tag_key_everyone_gets"] == "tag_value_just_a_gets" assert node_map["b"].tags["tag_key_everyone_gets"] == "tag_value_everyone_gets" + + +def test_typed_tags_success(): + """Tests the typed_tags decorator to ensure that it works in the basic case""" + + class FooType(TypedDict): + foo: str + bar: int + + foo = typed_tags(FooType) + + def dummy_tagged_function() -> int: + """dummy doc""" + return 1 + + node_ = foo(foo="foo", bar=1).decorate_node(node.Node.from_fn(dummy_tagged_function)) + assert node_.tags["foo"] == "foo" + assert node_.tags["bar"] == 1 + + +def test_typed_tags_wrong_type_failure(): + """Tests the typed_tags decorator to ensure that it breaks when the wrong types are passed""" + + class FooType(TypedDict): + foo: str + bar: int + + foo = typed_tags(FooType) + + with pytest.raises(InvalidDecoratorException): + + @foo(foo=1, bar="bar") + def dummy_tagged_function() -> int: + """dummy doc""" + return 1 + + +def test_typed_tags_illegal_types_failure(): + """Tests the typed_tags decorator to ensure that it breaks when illegal types are declared""" + + class FooType(TypedDict): + foo: Dict[str, dict] + bar: List[List[int]] + + foo = typed_tags(FooType) + + with pytest.raises(InvalidDecoratorException): + + @foo(foo=1, bar="bar") + def dummy_tagged_function() -> int: + """dummy doc""" + return 1 + + +def test_layered_tags_success(): + """Tests to ensure that layered tags are applied appropriately""" + + class FooType(TypedDict): + foo: str + bar: int + + class BarType(TypedDict): + bat: int + baz: List[int] + + foo = typed_tags(FooType) + bar = typed_tags(BarType) + + @foo(foo="foo", bar=1) + @bar(bat=2, baz=[1, 2, 3]) + def dummy_tagged_function() -> int: + """dummy doc""" + return 1 + + (node_,) = function_modifiers.base.resolve_nodes(dummy_tagged_function, {}) + assert node_.tags["foo"] == "foo" + assert node_.tags["bar"] == 1 + assert node_.tags["bat"] == 2 + assert node_.tags["baz"] == [1, 2, 3]