diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 4bba253..44c87dd 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -18,10 +18,12 @@ Optional, Union, Set, + Tuple, Type, NoReturn, ) +from hypothesis.errors import Frozen from typing_extensions import Protocol import jinja2 diff --git a/dbt_common/clients/jinja_macro_call.py b/dbt_common/clients/jinja_macro_call.py new file mode 100644 index 0000000..188e1c6 --- /dev/null +++ b/dbt_common/clients/jinja_macro_call.py @@ -0,0 +1,101 @@ +import dataclasses +from typing import Any, Dict, List, Optional + +import jinja2 + +from dbt_common.clients.jinja import get_environment, MacroType + +PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"] + + +@dataclasses.dataclass +class TypeCheckFailure: + msg: str + + +@dataclasses.dataclass +class DbtMacroCall: + """An instance of this class represents a jinja macro call in a template + for the purposes of recording information for type checking.""" + + name: str + source: str + arg_types: List[Optional[MacroType]] = dataclasses.field(default_factory=list) + kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall": + dbt_call = cls(name, "") + for arg in call.args: # type: ignore + dbt_call.arg_types.append(cls.get_type(arg)) + for arg in call.kwargs: # type: ignore + dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value) + return dbt_call + + @classmethod + def get_type(cls, param: Any) -> Optional[MacroType]: + if isinstance(param, jinja2.nodes.Name): + return None # TODO: infer types from variable names + + if isinstance(param, jinja2.nodes.Call): + return None # TODO: infer types from function/macro calls + + if isinstance(param, jinja2.nodes.Getattr): + return None # TODO: infer types from . operator + + if isinstance(param, jinja2.nodes.Concat): + return None + + if isinstance(param, jinja2.nodes.Const): + if isinstance(param.value, str): # type: ignore + return MacroType("str") + elif isinstance(param.value, bool): # type: ignore + return MacroType("bool") + elif isinstance(param.value, int): # type: ignore + return MacroType("int") + elif isinstance(param.value, float): # type: ignore + return MacroType("float") + elif param.value is None: # type: ignore + return None + else: + return None + + if isinstance(param, jinja2.nodes.Dict): + return None + + return None + + def is_valid_type(self, t: MacroType) -> bool: + if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES: + return True + elif ( + t.name == "Dict" + and len(t.type_params) == 2 + and t.type_params[0].name in PRIMITIVE_TYPES + and self.is_valid_type(t.type_params[1]) + ): + return True + elif ( + t.name in ["List", "Optional"] + and len(t.type_params) == 1 + and self.is_valid_type(t.type_params[0]) + ): + return True + + return False + + def check(self, macro_text: str) -> List[TypeCheckFailure]: + failures: List[TypeCheckFailure] = [] + template = get_environment(None, capture_macros=True).parse(macro_text) + jinja_macro = template.body[0] + + for arg_type in jinja_macro.arg_types: + if not self.is_valid_type(arg_type): + failures.append(TypeCheckFailure(msg="Invalid type.")) + + for i, arg_type in enumerate(self.arg_types): + expected_type = jinja_macro.arg_types[i] + if arg_type != expected_type: + failures.append(TypeCheckFailure(msg="Wrong type of parameter.")) + + return failures diff --git a/tests/unit/test_jinja_macro_call.py b/tests/unit/test_jinja_macro_call.py new file mode 100644 index 0000000..4b4301f --- /dev/null +++ b/tests/unit/test_jinja_macro_call.py @@ -0,0 +1,57 @@ +from dbt_common.clients.jinja import MacroType +from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall + + +single_param_macro_text = """{% macro call_me(param: TYPE) %} +{% endmacro %}""" + + +def test_primitive_type_checks() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {}) + assert not any(call.check(macro_text)) + + +def test_primitive_type_checks_wrong() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name) + call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {}) + assert any(call.check(macro_text)) + + +def test_list_type_checks() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]") + expected_type = MacroType("List", [MacroType(type_name)]) + call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + assert not any(call.check(macro_text)) + + +def test_dict_type_checks() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", f"Dict[{type_name}, {type_name}]") + expected_type = MacroType("Dict", [MacroType(type_name), MacroType(type_name)]) + call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + assert not any(call.check(macro_text)) + + +def test_too_few_args() -> None: + macro_text = "{% macro call_me(one: str, two: str, three: str) %}" + + +def test_too_many_args() -> None: + pass + + +kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %} +{% endmacro %}""" + + +# Better structured exceptions +# Test detection of macro called with too few positional args +# Test detection of macro called with too many positional args +# Test detection of macro called with keyword arg having wrong type +# Test detection of macro called with non-existent keyword arg +# Test detection of macro with invalid default value for param type