-
Notifications
You must be signed in to change notification settings - Fork 98
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SQL exprs needed for custom offset window #1575
Changes from all commits
4afb212
e48f6ad
c142551
9fb7994
d46303f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,12 +14,13 @@ | |
from dbt_semantic_interfaces.type_enums.date_part import DatePart | ||
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation | ||
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity | ||
from typing_extensions import override | ||
|
||
from metricflow_semantics.collection_helpers.merger import Mergeable | ||
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix | ||
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty | ||
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet | ||
from metricflow_semantics.visitor import Visitable, VisitorOutputT | ||
from typing_extensions import override | ||
|
||
|
||
@dataclass(frozen=True, eq=False) | ||
|
@@ -237,6 +238,18 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit | |
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102 | ||
pass | ||
|
||
@abstractmethod | ||
def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102 | ||
pass | ||
|
||
@abstractmethod | ||
def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> VisitorOutputT: # noqa: D102 | ||
pass | ||
|
||
@abstractmethod | ||
def visit_integer_expr(self, node: SqlIntegerExpression) -> VisitorOutputT: # noqa: D102 | ||
pass | ||
|
||
|
||
@dataclass(frozen=True, eq=False) | ||
class SqlStringExpression(SqlExpressionNode): | ||
|
@@ -375,6 +388,59 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 | |
return self.literal_value == other.literal_value | ||
|
||
|
||
@dataclass(frozen=True, eq=False) | ||
class SqlIntegerExpression(SqlExpressionNode): | ||
"""An integer like 1.""" | ||
|
||
integer_value: int | ||
|
||
@staticmethod | ||
def create(integer_value: int) -> SqlIntegerExpression: # noqa: D102 | ||
return SqlIntegerExpression(parent_nodes=(), integer_value=integer_value) | ||
|
||
@classmethod | ||
def id_prefix(cls) -> IdPrefix: # noqa: D102 | ||
return StaticIdPrefix.SQL_EXPR_INTEGER_PREFIX | ||
|
||
def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 | ||
return visitor.visit_integer_expr(self) | ||
|
||
@property | ||
def description(self) -> str: # noqa: D102 | ||
return f"Integer: {self.integer_value}" | ||
|
||
@property | ||
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 | ||
return tuple(super().displayed_properties) + (DisplayedProperty("value", self.integer_value),) | ||
|
||
@property | ||
def requires_parenthesis(self) -> bool: # noqa: D102 | ||
return False | ||
|
||
@property | ||
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 | ||
return SqlBindParameterSet() | ||
|
||
def __repr__(self) -> str: # noqa: D105 | ||
return f"{self.__class__.__name__}(node_id={self.node_id}, integer_value={self.integer_value})" | ||
|
||
def rewrite( # noqa: D102 | ||
self, | ||
column_replacements: Optional[SqlColumnReplacements] = None, | ||
should_render_table_alias: Optional[bool] = None, | ||
) -> SqlExpressionNode: | ||
return self | ||
|
||
@property | ||
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 | ||
return SqlExpressionTreeLineage(other_exprs=(self,)) | ||
|
||
def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 | ||
if not isinstance(other, SqlIntegerExpression): | ||
return False | ||
return self.integer_value == other.integer_value | ||
|
||
|
||
@dataclass(frozen=True) | ||
class SqlColumnReference: | ||
"""Used with string expressions to specify what columns are referred to in the string expression.""" | ||
|
@@ -950,17 +1016,38 @@ class SqlWindowFunction(Enum): | |
FIRST_VALUE = "FIRST_VALUE" | ||
LAST_VALUE = "LAST_VALUE" | ||
AVERAGE = "AVG" | ||
ROW_NUMBER = "ROW_NUMBER" | ||
LEAD = "LEAD" | ||
|
||
@property | ||
def requires_ordering(self) -> bool: | ||
"""Asserts whether or not ordering the window function will have an impact on the resulting value.""" | ||
if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE: | ||
if ( | ||
self is SqlWindowFunction.FIRST_VALUE | ||
or self is SqlWindowFunction.LAST_VALUE | ||
or self is SqlWindowFunction.ROW_NUMBER | ||
or self is SqlWindowFunction.LEAD | ||
): | ||
return True | ||
elif self is SqlWindowFunction.AVERAGE: | ||
return False | ||
else: | ||
assert_values_exhausted(self) | ||
|
||
@property | ||
def allows_frame_clause(self) -> bool: | ||
"""Whether the function allows a frame clause, e.g., 'ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'.""" | ||
if ( | ||
self is SqlWindowFunction.FIRST_VALUE | ||
or self is SqlWindowFunction.LAST_VALUE | ||
or self is SqlWindowFunction.AVERAGE | ||
): | ||
return True | ||
if self is SqlWindowFunction.ROW_NUMBER or self is SqlWindowFunction.LEAD: | ||
return False | ||
else: | ||
assert_values_exhausted(self) | ||
|
||
@classmethod | ||
def get_window_function_for_period_agg(cls, period_agg: PeriodAggregation) -> SqlWindowFunction: | ||
"""Get the window function to use for given period agg option.""" | ||
|
@@ -1106,7 +1193,8 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 | |
return ( | ||
self.sql_function == other.sql_function | ||
and self.order_by_args == other.order_by_args | ||
and self._parents_match(other) | ||
and self.partition_by_args == other.partition_by_args | ||
and self.sql_function_args == other.sql_function_args | ||
) | ||
|
||
|
||
|
@@ -1367,7 +1455,7 @@ def rewrite( # noqa: D102 | |
) -> SqlExpressionNode: | ||
return SqlAddTimeExpression.create( | ||
arg=self.arg.rewrite(column_replacements, should_render_table_alias), | ||
count_expr=self.count_expr, | ||
count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias), | ||
granularity=self.granularity, | ||
) | ||
|
||
|
@@ -1719,3 +1807,157 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 | |
|
||
def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 | ||
return False | ||
|
||
|
||
@dataclass(frozen=True, eq=False) | ||
class SqlCaseExpression(SqlExpressionNode): | ||
"""Renders a CASE WHEN expression.""" | ||
|
||
when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode] | ||
else_expr: Optional[SqlExpressionNode] | ||
|
||
@staticmethod | ||
def create( # noqa: D102 | ||
when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None | ||
) -> SqlCaseExpression: | ||
parent_nodes: Tuple[SqlExpressionNode, ...] = () | ||
for when, then in when_to_then_exprs.items(): | ||
parent_nodes += (when, then) | ||
|
||
if else_expr: | ||
parent_nodes += (else_expr,) | ||
|
||
return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr) | ||
|
||
@classmethod | ||
def id_prefix(cls) -> IdPrefix: # noqa: D102 | ||
return StaticIdPrefix.SQL_EXPR_CASE_PREFIX | ||
|
||
def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 | ||
return visitor.visit_case_expr(self) | ||
|
||
@property | ||
def description(self) -> str: # noqa: D102 | ||
return "Case expression" | ||
|
||
@property | ||
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 | ||
return super().displayed_properties | ||
|
||
@property | ||
def requires_parenthesis(self) -> bool: # noqa: D102 | ||
return False | ||
|
||
@property | ||
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 | ||
return SqlBindParameterSet() | ||
|
||
def __repr__(self) -> str: # noqa: D105 | ||
return f"{self.__class__.__name__}(node_id={self.node_id})" | ||
|
||
def rewrite( # noqa: D102 | ||
self, | ||
column_replacements: Optional[SqlColumnReplacements] = None, | ||
should_render_table_alias: Optional[bool] = None, | ||
) -> SqlExpressionNode: | ||
return SqlCaseExpression.create( | ||
when_to_then_exprs={ | ||
when.rewrite(column_replacements, should_render_table_alias): then.rewrite( | ||
column_replacements, should_render_table_alias | ||
) | ||
for when, then in self.when_to_then_exprs.items() | ||
}, | ||
else_expr=( | ||
self.else_expr.rewrite(column_replacements, should_render_table_alias) if self.else_expr else None | ||
), | ||
) | ||
|
||
@property | ||
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 | ||
return SqlExpressionTreeLineage.merge_iterable( | ||
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) | ||
) | ||
|
||
def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 | ||
if not isinstance(other, SqlCaseExpression): | ||
return False | ||
return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr | ||
|
||
|
||
class SqlArithmeticOperator(Enum): | ||
"""Arithmetic operator used to do math in a SQL expression.""" | ||
|
||
ADD = "+" | ||
SUBTRACT = "-" | ||
MULTIPLY = "*" | ||
DIVIDE = "/" | ||
|
||
|
||
@dataclass(frozen=True, eq=False) | ||
class SqlArithmeticExpression(SqlExpressionNode): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For potentially cleanup, seems like this can replace There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I considered that but decided not to because the ratio expression has a bunch of extra logic related to casting & |
||
"""An arithmetic expression using +, -, *, /. | ||
|
||
e.g. my_table.my_column + my_table.other_column | ||
|
||
Attributes: | ||
left_expr: The expression on the left side of the operator | ||
operator: The operator to use on the expressions | ||
right_expr: The expression on the right side of the operator | ||
""" | ||
|
||
left_expr: SqlExpressionNode | ||
operator: SqlArithmeticOperator | ||
right_expr: SqlExpressionNode | ||
|
||
@staticmethod | ||
def create( # noqa: D102 | ||
left_expr: SqlExpressionNode, operator: SqlArithmeticOperator, right_expr: SqlExpressionNode | ||
) -> SqlArithmeticExpression: | ||
return SqlArithmeticExpression( | ||
parent_nodes=(left_expr, right_expr), left_expr=left_expr, operator=operator, right_expr=right_expr | ||
) | ||
|
||
@classmethod | ||
def id_prefix(cls) -> IdPrefix: # noqa: D102 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not for this PR as we can do a batch cleanup, but we can now use |
||
return StaticIdPrefix.SQL_EXPR_ARITHMETIC_PREFIX | ||
|
||
def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 | ||
return visitor.visit_arithmetic_expr(self) | ||
|
||
@property | ||
def description(self) -> str: # noqa: D102 | ||
return "Arithmetic Expression" | ||
|
||
@property | ||
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 | ||
return tuple(super().displayed_properties) + ( | ||
DisplayedProperty("left_expr", self.left_expr), | ||
DisplayedProperty("operator", self.operator.value), | ||
DisplayedProperty("right_expr", self.right_expr), | ||
) | ||
|
||
@property | ||
def requires_parenthesis(self) -> bool: # noqa: D102 | ||
return True | ||
|
||
def rewrite( # noqa: D102 | ||
self, | ||
column_replacements: Optional[SqlColumnReplacements] = None, | ||
should_render_table_alias: Optional[bool] = None, | ||
) -> SqlExpressionNode: | ||
return SqlArithmeticExpression.create( | ||
left_expr=self.left_expr.rewrite(column_replacements, should_render_table_alias), | ||
operator=self.operator, | ||
right_expr=self.right_expr.rewrite(column_replacements, should_render_table_alias), | ||
) | ||
|
||
@property | ||
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 | ||
return SqlExpressionTreeLineage.merge_iterable( | ||
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) | ||
) | ||
|
||
def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 | ||
if not isinstance(other, SqlArithmeticExpression): | ||
return False | ||
return self.operator == other.operator and self._parents_match(other) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not for this PR as we can do a batch cleanup, but this can be moved to a common method in the base class.