Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Mergeable for SqlExpressionTreeLineage #1570

Merged
merged 2 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _reduce_parents(

@staticmethod
def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> bool:
combined_lineage = SqlExpressionTreeLineage.combine(
combined_lineage = SqlExpressionTreeLineage.merge_iterable(
tuple(x.expr.lineage for x in node.select_columns)
+ ((node.where.lineage,) if node.where else ())
+ tuple(x.expr.lineage for x in node.group_bys)
Expand All @@ -133,7 +133,7 @@ def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> b

@staticmethod
def _select_columns_contain_string_expressions(select_columns: Tuple[SqlSelectColumn, ...]) -> bool:
combined_lineage = SqlExpressionTreeLineage.combine(tuple(x.expr.lineage for x in select_columns))
combined_lineage = SqlExpressionTreeLineage.merge_iterable(tuple(x.expr.lineage for x in select_columns))

return len(combined_lineage.string_exprs) > 0

Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _search_for_expressions(
if select_node.where:
all_expr_search_results.append(select_node.where.lineage)

return SqlExpressionTreeLineage.combine(all_expr_search_results)
return SqlExpressionTreeLineage.merge_iterable(all_expr_search_results)

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
Expand Down
58 changes: 31 additions & 27 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from enum import Enum
from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple

import more_itertools
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.measure import MeasureAggregationParameters
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
Expand Down Expand Up @@ -100,7 +100,7 @@ def matches(self, other: SqlExpressionNode) -> bool:


@dataclass(frozen=True)
class SqlExpressionTreeLineage:
class SqlExpressionTreeLineage(Mergeable):
"""Captures the lineage of an expression node - contains itself and all ancestor nodes."""

string_exprs: Tuple[SqlStringExpression, ...] = ()
Expand All @@ -109,19 +109,6 @@ class SqlExpressionTreeLineage:
column_alias_reference_exprs: Tuple[SqlColumnAliasReferenceExpression, ...] = ()
other_exprs: Tuple[SqlExpressionNode, ...] = ()

@staticmethod
def combine(lineages: Sequence[SqlExpressionTreeLineage]) -> SqlExpressionTreeLineage:
"""Combine multiple lineages into one lineage, without de-duping."""
return SqlExpressionTreeLineage(
string_exprs=tuple(more_itertools.flatten(tuple(x.string_exprs for x in lineages))),
function_exprs=tuple(more_itertools.flatten(tuple(x.function_exprs for x in lineages))),
column_reference_exprs=tuple(more_itertools.flatten(tuple(x.column_reference_exprs for x in lineages))),
column_alias_reference_exprs=tuple(
more_itertools.flatten(tuple(x.column_alias_reference_exprs for x in lineages))
),
other_exprs=tuple(more_itertools.flatten(tuple(x.other_exprs for x in lineages))),
)

@property
def contains_string_exprs(self) -> bool: # noqa: D102
return len(self.string_exprs) > 0
Expand All @@ -138,6 +125,21 @@ def contains_ambiguous_exprs(self) -> bool: # noqa: D102
def contains_aggregate_exprs(self) -> bool: # noqa: D102
return any(x.is_aggregate_function for x in self.function_exprs)

@override
def merge(self, other: SqlExpressionTreeLineage) -> SqlExpressionTreeLineage:
return SqlExpressionTreeLineage(
string_exprs=self.string_exprs + other.string_exprs,
function_exprs=self.function_exprs + other.function_exprs,
column_reference_exprs=self.column_reference_exprs + other.column_reference_exprs,
column_alias_reference_exprs=self.column_alias_reference_exprs + other.column_alias_reference_exprs,
other_exprs=self.other_exprs + other.other_exprs,
)

@classmethod
@override
def empty_instance(cls) -> SqlExpressionTreeLineage:
return SqlExpressionTreeLineage()


class SqlColumnReplacements:
"""When re-writing column references in expressions, this stores the mapping."""
Expand Down Expand Up @@ -604,7 +606,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -803,7 +805,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -923,7 +925,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -1084,7 +1086,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -1194,7 +1196,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1241,7 +1243,9 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine([self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))])
return SqlExpressionTreeLineage.merge_iterable(
[self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))]
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlIsNullExpression):
Expand Down Expand Up @@ -1304,7 +1308,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1351,7 +1355,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1402,7 +1406,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1461,7 +1465,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1526,7 +1530,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1591,7 +1595,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down
Loading