Skip to content

Commit

Permalink
Merge branch 'main' into mashumaro_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Dec 11, 2024
2 parents 56b3592 + 243568e commit 2c0e434
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241210-144247.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add syntax support for types on macro parameters.
time: 2024-12-10T14:42:47.253157-05:00
custom:
Author: peterallenwebb
Issue: "229"
66 changes: 66 additions & 0 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import codecs
import dataclasses
import linecache
import os
import tempfile
Expand Down Expand Up @@ -90,6 +91,12 @@ def _linecache_inject(source: str, write: bool) -> str:
return filename


@dataclasses.dataclass
class MacroType:
name: str
type_params: List["MacroType"] = dataclasses.field(default_factory=list)


class MacroFuzzParser(jinja2.parser.Parser):
def parse_macro(self) -> jinja2.nodes.Macro:
node = jinja2.nodes.Macro(lineno=next(self.stream).lineno)
Expand All @@ -103,6 +110,65 @@ def parse_macro(self) -> jinja2.nodes.Macro:
node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
return node

def parse_signature(self, node: Union[jinja2.nodes.Macro, jinja2.nodes.CallBlock]) -> None:
"""Overrides the default jinja Parser.parse_signature method, modifying
the original implementation to allow macros to have typed parameters."""

# Jinja does not support extending its node types, such as Macro, so
# at least while typed macros are experimental, we will patch the
# information onto the existing types.
setattr(node, "arg_types", [])
setattr(node, "has_type_annotations", False)

args = node.args = [] # type: ignore
defaults = node.defaults = [] # type: ignore

self.stream.expect("lparen")
while self.stream.current.type != "rparen":
if args:
self.stream.expect("comma")

arg = self.parse_assign_target(name_only=True)
arg.set_ctx("param")

type_name: Optional[str]
if self.stream.skip_if("colon"):
node.has_type_annotations = True # type: ignore
type_name = self.parse_type_name()
else:
type_name = ""

node.arg_types.append(type_name) # type: ignore

if self.stream.skip_if("assign"):
defaults.append(self.parse_expression())
elif defaults:
self.fail("non-default argument follows default argument")

args.append(arg)
self.stream.expect("rparen")

def parse_type_name(self) -> MacroType:
# NOTE: Types syntax is validated here, but not whether type names
# are valid or have correct parameters.

# A type name should consist of a name (i.e. 'Dict')...
type_name = self.stream.expect("name").value
type = MacroType(type_name)

# ..and an optional comma-delimited list of type parameters
# as in the type declaration 'Dict[str, str]'
if self.stream.skip_if("lbracket"):
while self.stream.current.type != "rbracket":
if type.type_params:
self.stream.expect("comma")
param_type = self.parse_type_name()
type.type_params.append(param_type)

self.stream.expect("rbracket")

return type


class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
def _parse(
Expand Down
50 changes: 47 additions & 3 deletions tests/unit/test_jinja.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import jinja2
import unittest

from typing import Any, Dict

from dbt_common.clients._jinja_blocks import BlockTag
from dbt_common.clients.jinja import extract_toplevel_blocks, get_template, render_template
from dbt_common.clients.jinja import (
extract_toplevel_blocks,
get_template,
render_template,
MacroFuzzParser,
MacroType,
)
from dbt_common.exceptions import CompilationError


Expand Down Expand Up @@ -505,7 +514,7 @@ def test_if_endfor_newlines(self) -> None:
"""


def test_if_list_filter():
def test_if_list_filter() -> None:
jinja_string = """
{%- if my_var | is_list -%}
Found a list
Expand All @@ -514,7 +523,7 @@ def test_if_list_filter():
{%- endif -%}
"""
# Check with list variable
ctx = {"my_var": ["one", "two"]}
ctx: Dict[str, Any] = {"my_var": ["one", "two"]}
template = get_template(jinja_string, ctx)
rendered = render_template(template, ctx)
assert "Found a list" in rendered
Expand All @@ -524,3 +533,38 @@ def test_if_list_filter():
template = get_template(jinja_string, ctx)
rendered = render_template(template, ctx)
assert "Did not find a list" in rendered


def test_macro_parser_parses_simple_types() -> None:
macro_txt = """
{% macro test_macro(param1: str, param2: int, param3: bool, param4: float, param5: Any) %}
{% endmacro %}
"""

env = jinja2.Environment()
parser = MacroFuzzParser(env, macro_txt)
result = parser.parse()
arg_types = result.body[1].arg_types
assert arg_types[0] == MacroType("str")
assert arg_types[1] == MacroType("int")
assert arg_types[2] == MacroType("bool")
assert arg_types[3] == MacroType("float")
assert arg_types[4] == MacroType("Any")


def test_macro_parser_parses_complex_types() -> None:
macro_txt = """
{% macro test_macro(param1: List[str], param2: Dict[ int,str ], param3: Optional[List[str]], param4: Dict[str, Dict[bool, Any]]) %}
{% endmacro %}
"""

env = jinja2.Environment()
parser = MacroFuzzParser(env, macro_txt)
result = parser.parse()
arg_types = result.body[1].arg_types
assert arg_types[0] == MacroType("List", [MacroType("str")])
assert arg_types[1] == MacroType("Dict", [MacroType("int"), MacroType("str")])
assert arg_types[2] == MacroType("Optional", [MacroType("List", [MacroType("str")])])
assert arg_types[3] == MacroType(
"Dict", [MacroType("str"), MacroType("Dict", [MacroType("bool"), MacroType("Any")])]
)

0 comments on commit 2c0e434

Please sign in to comment.