Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Jan 14, 2025
1 parent de91f53 commit 7b081b3
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 16 deletions.
74 changes: 61 additions & 13 deletions dbt_common/clients/jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from enum import Enum
from typing import Any, Dict, List, Optional, Iterable
from typing import Any, Dict, List, Optional

import jinja2
import jinja2.nodes
Expand All @@ -18,11 +18,13 @@ class FailureType(Enum):
EXTRA_ARGUMENT = "extra_arg"
MISSING_ARGUMENT = "missing_arg"


@dataclasses.dataclass
class TypeCheckFailure:
type: FailureType
msg: str


@dataclasses.dataclass
class MacroCallChecker:
"""An instance of this class represents a jinja macro call in a template
Expand Down Expand Up @@ -56,31 +58,61 @@ def check(self, macro_text: str) -> List[TypeCheckFailure]:
target_type = macro_checker.arg_types[i]
unassigned_args.remove(target_name)
if arg_type is not None and target_type is not None and arg_type != target_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {target_type.name} for argument {target_name} but found {arg_type.name}/"))
failures.append(
TypeCheckFailure(
FailureType.TYPE_MISMATCH,
f"Expected type {target_type.name} for argument {target_name} but found {arg_type.name}/",
)
)

# Each keyword argument in this call should correspond to an expected
# argument that has not already been assigned, and have a compatible type.
for arg_name, arg_type in self.kwarg_types.items():
if arg_name not in macro_checker.args:
failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Unexpected keyword argument {arg_name}."))
failures.append(
TypeCheckFailure(
FailureType.EXTRA_ARGUMENT, f"Unexpected keyword argument {arg_name}."
)
)
elif arg_name not in unassigned_args:
failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Argument {arg_name} was specified more than once."))
failures.append(
TypeCheckFailure(
FailureType.EXTRA_ARGUMENT,
f"Argument {arg_name} was specified more than once.",
)
)
else:
unassigned_args.remove(arg_name)
expected_type = macro_checker.get_arg_type(arg_name)
if arg_type is not None and expected_type is not None and arg_type != expected_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} for argument {arg_name} but found {arg_type.name}/"))
if (
arg_type is not None
and expected_type is not None
and arg_type != expected_type
):
failures.append(
TypeCheckFailure(
FailureType.TYPE_MISMATCH,
f"Expected type {expected_type.name} for argument {arg_name} but found {arg_type.name}/",
)
)

# Any remaining unassigned parameters must have a default.
for arg_name in unassigned_args:
if not macro_checker.has_default(arg_name):
failures.append(TypeCheckFailure(FailureType.MISSING_ARGUMENT, f"Missing argument {arg_name}."))
failures.append(
TypeCheckFailure(FailureType.MISSING_ARGUMENT, f"Missing argument {arg_name}.")
)

# Check that any arguments specified by keyword have the correct type
for arg_name, arg_type in self.kwarg_types.items():
expected_type = macro_checker.get_arg_type(arg_name)
if arg_type is not None and expected_type is not None and arg_type != expected_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {arg_name} but found {arg_type.name}/"))
failures.append(
TypeCheckFailure(
FailureType.TYPE_MISMATCH,
f"Expected type {expected_type.name} as argument {arg_name} but found {arg_type.name}/",
)
)

return failures

Expand Down Expand Up @@ -143,23 +175,39 @@ def check(t: Optional[MacroType]) -> List[TypeCheckFailure]:
failures: List[TypeCheckFailure] = []
if t.name == "Dict":
if len(t.type_params) != 2:
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}."))
failures.append(
TypeCheckFailure(
FailureType.PARAMETER_COUNT,
f"Expected two type parameters for Dict[], found {len(t.type_params)}.",
)
)
else:
if t.type_params[0].name not in PRIMITIVE_TYPES:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type."))
failures.append(
TypeCheckFailure(
FailureType.TYPE_MISMATCH,
"First type parameter of Dict[] must be a primitive type.",
)
)

failures.extend(TypeChecker.check(t.type_params[1]))
elif t.name in ("List", "Optional"):
if len(t.type_params) != 1:
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected one type parameter for {t.name}[], found {len(t.type_params)}."))
failures.append(
TypeCheckFailure(
FailureType.PARAMETER_COUNT,
f"Expected one type parameter for {t.name}[], found {len(t.type_params)}.",
)
)

failures.extend(TypeChecker.check(t.type_params[0]))
else:
failures.append(TypeCheckFailure(FailureType.UNKNOWN_TYPE, f"Unknown type {t.name} encountered."))
failures.append(
TypeCheckFailure(FailureType.UNKNOWN_TYPE, f"Unknown type {t.name} encountered.")
)

return failures


@staticmethod
def get_type(param: Any) -> Optional[MacroType]:
if isinstance(param, jinja2.nodes.Name):
Expand Down
18 changes: 15 additions & 3 deletions tests/unit/test_jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from dbt_common.clients.jinja import MacroType
from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DBT_CLASSES, FailureType, MacroCallChecker, MacroChecker
from dbt_common.clients.jinja_macro_call import (
PRIMITIVE_TYPES,
DBT_CLASSES,
FailureType,
MacroCallChecker,
MacroChecker,
)

single_param_macro_text = """{% macro call_me(param: TYPE) %}
{% endmacro %}"""
Expand All @@ -23,6 +29,7 @@ def test_dbt_class_type_checks() -> None:
failures = call.check(macro_text)
assert not failures


def test_type_checks_wrong() -> None:
"""Test that calls to annotated macros with incorrect types fail type checks."""
for type_name in PRIMITIVE_TYPES + DBT_CLASSES:
Expand Down Expand Up @@ -62,15 +69,19 @@ def test_too_few_pos_args() -> None:


def test_unknown_kwarg() -> None:
call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"unk": MacroType("str")})
call = MacroCallChecker(
"call_me", "", [MacroType("int"), MacroType("int")], {"unk": MacroType("str")}
)
failures = call.check(kwarg_param_macro_text)
assert len(failures) == 1
assert failures[0].type == FailureType.EXTRA_ARGUMENT


def test_kwarg_type() -> None:
"""Test that annotated kwargs pass type checks when used by name."""
call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"arg3": MacroType("str")})
call = MacroCallChecker(
"call_me", "", [MacroType("int"), MacroType("int")], {"arg3": MacroType("str")}
)
failures = call.check(kwarg_param_macro_text)
assert not failures

Expand All @@ -81,6 +92,7 @@ def test_wrong_kwarg_type() -> None:
failures = call.check(kwarg_param_macro_text)
assert failures[0].type == FailureType.TYPE_MISMATCH


# TODO: Test detection of macro with invalid default value for param type
# TODO: Test detection of macro called with invalid variable parameter, as known from macro parameter annotation.

Expand Down

0 comments on commit 7b081b3

Please sign in to comment.