Skip to content

Commit

Permalink
Add SQL exprs needed for custom offset window
Browse files Browse the repository at this point in the history
This includes a CASE expression, an integer expression, and some updates to the add time expression & the window function expression.
  • Loading branch information
courtneyholcomb committed Dec 18, 2024
1 parent 2e1cabb commit b1435ef
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 30 deletions.
3 changes: 3 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
SQL_EXPR_GENERATE_UUID_PREFIX = "uuid"
SQL_EXPR_CASE_PREFIX = "case"
SQL_EXPR_ARITHMETIC_PREFIX = "arit"
SQL_EXPR_INTEGER_PREFIX = "int"

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
Expand Down
237 changes: 233 additions & 4 deletions metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.visitor import Visitable, VisitorOutputT
from typing_extensions import override


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -237,6 +238,18 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_integer_expr(self, node: SqlIntegerExpression) -> VisitorOutputT: # noqa: D102
pass


@dataclass(frozen=True, eq=False)
class SqlStringExpression(SqlExpressionNode):
Expand Down Expand Up @@ -375,6 +388,59 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.literal_value == other.literal_value


@dataclass(frozen=True, eq=False)
class SqlIntegerExpression(SqlExpressionNode):
"""An integer like 1."""

integer_value: int

@staticmethod
def create(integer_value: int) -> SqlIntegerExpression: # noqa: D102
return SqlIntegerExpression(parent_nodes=(), integer_value=integer_value)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_INTEGER_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_integer_expr(self)

@property
def description(self) -> str: # noqa: D102
return f"Integer: {self.integer_value}"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (DisplayedProperty("value", self.integer_value),)

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

@property
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet()

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(node_id={self.node_id}, integer_value={self.integer_value})"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return self

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage(other_exprs=(self,))

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlIntegerExpression):
return False
return self.integer_value == other.integer_value


@dataclass(frozen=True)
class SqlColumnReference:
"""Used with string expressions to specify what columns are referred to in the string expression."""
Expand Down Expand Up @@ -950,11 +1016,18 @@ class SqlWindowFunction(Enum):
FIRST_VALUE = "FIRST_VALUE"
LAST_VALUE = "LAST_VALUE"
AVERAGE = "AVG"
ROW_NUMBER = "ROW_NUMBER"
LAG = "LAG"

@property
def requires_ordering(self) -> bool:
"""Asserts whether or not ordering the window function will have an impact on the resulting value."""
if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE:
if (
self is SqlWindowFunction.FIRST_VALUE
or self is SqlWindowFunction.LAST_VALUE
or self is SqlWindowFunction.ROW_NUMBER
or self is SqlWindowFunction.LAG
):
return True
elif self is SqlWindowFunction.AVERAGE:
return False
Expand Down Expand Up @@ -1106,7 +1179,8 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return (
self.sql_function == other.sql_function
and self.order_by_args == other.order_by_args
and self._parents_match(other)
and self.partition_by_args == other.partition_by_args
and self.sql_function_args == other.sql_function_args
)


Expand Down Expand Up @@ -1367,7 +1441,7 @@ def rewrite( # noqa: D102
) -> SqlExpressionNode:
return SqlAddTimeExpression.create(
arg=self.arg.rewrite(column_replacements, should_render_table_alias),
count_expr=self.count_expr,
count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias),
granularity=self.granularity,
)

Expand Down Expand Up @@ -1719,3 +1793,158 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False


@dataclass(frozen=True, eq=False)
class SqlCaseExpression(SqlExpressionNode):
"""Renders a CASE WHEN expression."""

when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode]
else_expr: Optional[SqlExpressionNode]

@staticmethod
def create( # noqa: D102
when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None
) -> SqlCaseExpression:
parent_nodes: Tuple[SqlExpressionNode, ...] = ()
for when, then in when_to_then_exprs.items():
parent_nodes += (when,)
parent_nodes += (then,)

if else_expr:
parent_nodes += (else_expr,)

return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_CASE_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_case_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Case expression"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return super().displayed_properties

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

@property
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet()

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(node_id={self.node_id})"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlCaseExpression.create(
when_to_then_exprs={
when.rewrite(column_replacements, should_render_table_alias): then.rewrite(
column_replacements, should_render_table_alias
)
for when, then in self.when_to_then_exprs.items()
},
else_expr=(
self.else_expr.rewrite(column_replacements, should_render_table_alias) if self.else_expr else None
),
)

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlCaseExpression):
return False
return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr


class SqlArithmeticOperator(Enum):
"""Arithmetic operator used to do math in a SQL expression."""

ADD = "+"
SUBTRACT = "-"
MULTIPLY = "*"
DIVIDE = "/"


@dataclass(frozen=True, eq=False)
class SqlArithmeticExpression(SqlExpressionNode):
"""An arithmetic expression using +, -, *, /.
e.g. my_table.my_column + my_table.other_column
Attributes:
left_expr: The expression on the left side of the operator
operator: The operator to use on the expressions
right_expr: The expression on the right side of the operator
"""

left_expr: SqlExpressionNode
operator: SqlArithmeticOperator
right_expr: SqlExpressionNode

@staticmethod
def create( # noqa: D102
left_expr: SqlExpressionNode, operator: SqlArithmeticOperator, right_expr: SqlExpressionNode
) -> SqlArithmeticExpression:
return SqlArithmeticExpression(
parent_nodes=(left_expr, right_expr), left_expr=left_expr, operator=operator, right_expr=right_expr
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_ARITHMETIC_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_arithmetic_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Arithmetic Expression"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("left_expr", self.left_expr),
DisplayedProperty("operator", self.operator.value),
DisplayedProperty("right_expr", self.right_expr),
)

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return True

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlArithmeticExpression.create(
left_expr=self.left_expr.rewrite(column_replacements, should_render_table_alias),
operator=self.operator,
right_expr=self.right_expr.rewrite(column_replacements, should_render_table_alias),
)

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlArithmeticExpression):
return False
return self.operator == other.operator and self._parents_match(other)
4 changes: 2 additions & 2 deletions metricflow/sql/render/big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender
count = node.count_expr.accept(self)

return SqlExpressionRenderResult(
sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count} {node.granularity.value})",
bind_parameter_set=column.bind_parameter_set,
sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count.sql} {node.granularity.value})",
bind_parameter_set=column.bind_parameter_set.merge(count.bind_parameter_set),
)

@override
Expand Down
23 changes: 17 additions & 6 deletions metricflow/sql/render/duckdb_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.sql.sql_exprs import (
SqlAddTimeExpression,
SqlArithmeticExpression,
SqlArithmeticOperator,
SqlGenerateUuidExpression,
SqlIntegerExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
SqlSubtractTimeIntervalExpression,
Expand Down Expand Up @@ -56,17 +59,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta expression for DuckDB, which requires slightly different syntax from other engines."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
count_expr = node.count_expr
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"({count_rendered} * 3)"
count_expr = SqlArithmeticExpression.create(
left_expr=node.count_expr,
operator=SqlArithmeticOperator.MULTIPLY,
right_expr=SqlIntegerExpression.create(3),
)

arg_rendered = node.arg.accept(self)
count_rendered = count_expr.accept(self)
count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}",
bind_parameter_set=arg_rendered.bind_parameter_set,
sql=f"{arg_rendered.sql} + INTERVAL {count_sql} {granularity.value}",
bind_parameter_set=SqlBindParameterSet.merge_iterable(
(arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set)
),
)

@override
Expand Down
Loading

0 comments on commit b1435ef

Please sign in to comment.