Skip to content

Commit

Permalink
Move some code from dbt-core to dbt-common.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Dec 13, 2024
1 parent 243568e commit 3a20cdc
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
Optional,
Union,
Set,
Tuple,
Type,
NoReturn,
)

from hypothesis.errors import Frozen
from typing_extensions import Protocol

import jinja2
Expand Down
101 changes: 101 additions & 0 deletions dbt_common/clients/jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import dataclasses
from typing import Any, Dict, List, Optional

import jinja2

from dbt_common.clients.jinja import get_environment, MacroType

PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"]


@dataclasses.dataclass
class TypeCheckFailure:
msg: str


@dataclasses.dataclass
class DbtMacroCall:
"""An instance of this class represents a jinja macro call in a template
for the purposes of recording information for type checking."""

name: str
source: str
arg_types: List[Optional[MacroType]] = dataclasses.field(default_factory=list)
kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict)

@classmethod
def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall":
dbt_call = cls(name, "")
for arg in call.args: # type: ignore
dbt_call.arg_types.append(cls.get_type(arg))
for arg in call.kwargs: # type: ignore
dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value)
return dbt_call

@classmethod
def get_type(cls, param: Any) -> Optional[MacroType]:
if isinstance(param, jinja2.nodes.Name):
return None # TODO: infer types from variable names

if isinstance(param, jinja2.nodes.Call):
return None # TODO: infer types from function/macro calls

if isinstance(param, jinja2.nodes.Getattr):
return None # TODO: infer types from . operator

if isinstance(param, jinja2.nodes.Concat):
return None

if isinstance(param, jinja2.nodes.Const):
if isinstance(param.value, str): # type: ignore
return MacroType("str")
elif isinstance(param.value, bool): # type: ignore
return MacroType("bool")
elif isinstance(param.value, int): # type: ignore
return MacroType("int")
elif isinstance(param.value, float): # type: ignore
return MacroType("float")
elif param.value is None: # type: ignore
return None
else:
return None

if isinstance(param, jinja2.nodes.Dict):
return None

return None

def is_valid_type(self, t: MacroType) -> bool:
if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES:
return True
elif (
t.name == "Dict"
and len(t.type_params) == 2
and t.type_params[0].name in PRIMITIVE_TYPES
and self.is_valid_type(t.type_params[1])
):
return True
elif (
t.name in ["List", "Optional"]
and len(t.type_params) == 1
and self.is_valid_type(t.type_params[0])
):
return True

return False

def check(self, macro_text: str) -> List[TypeCheckFailure]:
failures: List[TypeCheckFailure] = []
template = get_environment(None, capture_macros=True).parse(macro_text)
jinja_macro = template.body[0]

for arg_type in jinja_macro.arg_types:
if not self.is_valid_type(arg_type):
failures.append(TypeCheckFailure(msg="Invalid type."))

for i, arg_type in enumerate(self.arg_types):
expected_type = jinja_macro.arg_types[i]
if arg_type != expected_type:
failures.append(TypeCheckFailure(msg="Wrong type of parameter."))

return failures
57 changes: 57 additions & 0 deletions tests/unit/test_jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from dbt_common.clients.jinja import MacroType
from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall


single_param_macro_text = """{% macro call_me(param: TYPE) %}
{% endmacro %}"""


def test_primitive_type_checks() -> None:
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", type_name)
call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {})
assert not any(call.check(macro_text))


def test_primitive_type_checks_wrong() -> None:
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", type_name)
wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name)
call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {})
assert any(call.check(macro_text))


def test_list_type_checks() -> None:
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]")
expected_type = MacroType("List", [MacroType(type_name)])
call = DbtMacroCall("call_me", "call_me", [expected_type], {})
assert not any(call.check(macro_text))


def test_dict_type_checks() -> None:
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", f"Dict[{type_name}, {type_name}]")
expected_type = MacroType("Dict", [MacroType(type_name), MacroType(type_name)])
call = DbtMacroCall("call_me", "call_me", [expected_type], {})
assert not any(call.check(macro_text))


def test_too_few_args() -> None:
macro_text = "{% macro call_me(one: str, two: str, three: str) %}"


def test_too_many_args() -> None:
pass


kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %}
{% endmacro %}"""


# Better structured exceptions
# Test detection of macro called with too few positional args
# Test detection of macro called with too many positional args
# Test detection of macro called with keyword arg having wrong type
# Test detection of macro called with non-existent keyword arg
# Test detection of macro with invalid default value for param type

0 comments on commit 3a20cdc

Please sign in to comment.