Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Prototype for typed_tag. Solve
Browse files Browse the repository at this point in the history
#276

This is a little rough -- we need more testing. But it should work.
Followed the spec at it was fairly clean -- using typed dicts, and
validating on decoration. Also restricting to only primitives -- we can
change it later but I want to keep things a lot simpler for now.
  • Loading branch information
elijahbenizzy committed Jan 15, 2023
1 parent 0469db3 commit 8c48744
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 3 deletions.
83 changes: 80 additions & 3 deletions hamilton/function_modifiers/metadata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Callable, Dict
import typing
from typing import Any, Callable, Dict, Type, TypedDict

import typing_inspect

from hamilton import node
from hamilton.function_modifiers import base
Expand Down Expand Up @@ -45,13 +48,17 @@ class tag(base.NodeDecorator):
"module",
] # Anything that starts with any of these is banned, the framework reserves the right to manage it

def __init__(self, **tags: str):
def __init__(self, *, __validate_tag_types: bool = True, **tags: str):
"""Constructor for adding tag annotations to a function.
:param tags: the keys are always going to be strings, so the type annotation here means the values are strings.
Implicitly this is `Dict[str, str]` but the PEP guideline is to only annotate it with `str`.
:param __validate_tag_types: If true, we validate the types of the tags. This is called by the framework, and
should not be called by users. If you want to have more than just str valued tags, consider using typed tags
as specified below.
"""
self.tags = tags
self.__validate_tag_types = __validate_tag_types

def decorate_node(self, node_: node.Node) -> node.Node:
"""Decorates the nodes produced by this with the specified tags
Expand Down Expand Up @@ -105,8 +112,11 @@ def validate(self, fn: Callable):
"""
bad_tags = set()
for key, value in self.tags.items():
if (not tag._key_allowed(key)) or (not tag._value_allowed(value)):
if not tag._key_allowed(key):
bad_tags.add((key, value))
if not tag._value_allowed(value) and not self.__validate_tag_types:
bad_tags.add((key, value))

if bad_tags:
bad_tags_formatted = ",".join([f"{key}={value}" for key, value in bad_tags])
raise base.InvalidDecoratorException(
Expand All @@ -132,3 +142,70 @@ def decorate_node(self, node_: node.Node) -> node.Node:
new_tags = node_.tags.copy()
new_tags.update(self.tag_mapping.get(node_.name, {}))
return tag(**new_tags).decorate_node(node_)


# class TypedTagSet(TypedDict):
# """A typed tag set is a dictionary of tags that are typed. We do additional validation on this
# to ensure that the right types are created and that that ri"""


def _type_allowed(type: Type[Type], allow_lists: bool = True) -> bool:
"""Validates that a type is allowed. We only allow primitive types and lists of primitive types"""
if type in [int, float, str, bool]:
return True
if allow_lists:
if typing_inspect.is_generic_type(type):
if typing_inspect.get_origin(type) == list:
return _type_allowed(typing_inspect.get_args(type)[0], allow_lists=False)
return False


def _validate_spec(typed_dict_class: Type[TypedDict]):
invalid_types = []
for key, value in typing.get_type_hints(typed_dict_class).items():
if not _type_allowed(value, allow_lists=True):
invalid_types.append((key, value))
if invalid_types:
invalid_types_formatted = ",".join([f"{key}={value}" for key, value in invalid_types])
raise base.InvalidDecoratorException(
f"The following key/value pairs are invalid as types: {invalid_types_formatted} "
"Types can be any primitive type or a list of a primitive type."
)


def _type_matches(value: Any, type_: Type[Type]):
if type_ in [int, float, str, bool]:
return isinstance(value, type_)
if typing_inspect.is_generic_type(type_):
if typing_inspect.get_origin(type_) == list:
return isinstance(value, list) and all(
_type_matches(item, typing_inspect.get_args(type_)[0]) for item in value
)
return False


def _validate_values(typed_dict: dict, typed_dict_class: Type[TypedDict]):
invalid_pairs = []
for key, value in typed_dict.items():
if not _type_matches(value, typing.get_type_hints(typed_dict_class)[key]):
invalid_pairs.append((key, value))
if invalid_pairs:
invalid_pairs_formatted = ",".join([f"{key}={value}" for key, value in invalid_pairs])
raise base.InvalidDecoratorException(
f"The following key/value pairs are invalid as values: {invalid_pairs_formatted} "
"Values must match the specified type."
)


def validate_typed_dict(data: dict, typed_dict_class: TypedDict):
_validate_spec(typed_dict_class)
_validate_values(data, typed_dict_class)


class typed_tags:
def __init__(self, typed_tag_class: TypedDict):
self.tag_set_type = typed_tag_class

def __call__(self, **kwargs: Any):
validate_typed_dict(dict(**kwargs), self.tag_set_type)
return tag(**kwargs, __validate_tag_types=False) # types are already validated
83 changes: 83 additions & 0 deletions tests/function_modifiers/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Dict, List, TypedDict

import pandas as pd
import pytest

from hamilton import function_modifiers, node
from hamilton.function_modifiers.base import InvalidDecoratorException
from hamilton.function_modifiers.metadata import typed_tags


def test_tags():
Expand Down Expand Up @@ -108,3 +112,82 @@ def dummy_tagged_function() -> pd.DataFrame:
assert node_map["b"].tags["tag_b_gets"] == "tag_value_b_gets"
assert node_map["a"].tags["tag_key_everyone_gets"] == "tag_value_just_a_gets"
assert node_map["b"].tags["tag_key_everyone_gets"] == "tag_value_everyone_gets"


def test_typed_tags_success():
"""Tests the typed_tags decorator to ensure that it works in the basic case"""

class FooType(TypedDict):
foo: str
bar: int

foo = typed_tags(FooType)

def dummy_tagged_function() -> int:
"""dummy doc"""
return 1

node_ = foo(foo="foo", bar=1).decorate_node(node.Node.from_fn(dummy_tagged_function))
assert node_.tags["foo"] == "foo"
assert node_.tags["bar"] == 1


def test_typed_tags_wrong_type_failure():
"""Tests the typed_tags decorator to ensure that it breaks when the wrong types are passed"""

class FooType(TypedDict):
foo: str
bar: int

foo = typed_tags(FooType)

with pytest.raises(InvalidDecoratorException):

@foo(foo=1, bar="bar")
def dummy_tagged_function() -> int:
"""dummy doc"""
return 1


def test_typed_tags_illegal_types_failure():
"""Tests the typed_tags decorator to ensure that it breaks when illegal types are declared"""

class FooType(TypedDict):
foo: Dict[str, dict]
bar: List[List[int]]

foo = typed_tags(FooType)

with pytest.raises(InvalidDecoratorException):

@foo(foo=1, bar="bar")
def dummy_tagged_function() -> int:
"""dummy doc"""
return 1


def test_layered_tags_success():
"""Tests to ensure that layered tags are applied appropriately"""

class FooType(TypedDict):
foo: str
bar: int

class BarType(TypedDict):
bat: int
baz: List[int]

foo = typed_tags(FooType)
bar = typed_tags(BarType)

@foo(foo="foo", bar=1)
@bar(bat=2, baz=[1, 2, 3])
def dummy_tagged_function() -> int:
"""dummy doc"""
return 1

(node_,) = function_modifiers.base.resolve_nodes(dummy_tagged_function, {})
assert node_.tags["foo"] == "foo"
assert node_.tags["bar"] == 1
assert node_.tags["bat"] == 2
assert node_.tags["baz"] == [1, 2, 3]

0 comments on commit 8c48744

Please sign in to comment.