From 7d1eff28cf6456d118afddfc7dabc804565b4588 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Tue, 10 Dec 2024 18:08:44 -0800 Subject: [PATCH] /* PR_START p--misc 06 */ Implement `Mergeable` for `SqlExpressionTreeLineage` --- .../optimizer/rewriting_sub_query_reducer.py | 4 +- .../optimizer/tag_required_column_aliases.py | 2 +- metricflow/sql/sql_exprs.py | 61 ++++++++++++------- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 053f0d7f4..bd4fbb87f 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -122,7 +122,7 @@ def _reduce_parents( @staticmethod def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> bool: - combined_lineage = SqlExpressionTreeLineage.combine( + combined_lineage = SqlExpressionTreeLineage.merge_iterable( tuple(x.expr.lineage for x in node.select_columns) + ((node.where.lineage,) if node.where else ()) + tuple(x.expr.lineage for x in node.group_bys) @@ -133,7 +133,7 @@ def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> b @staticmethod def _select_columns_contain_string_expressions(select_columns: Tuple[SqlSelectColumn, ...]) -> bool: - combined_lineage = SqlExpressionTreeLineage.combine(tuple(x.expr.lineage for x in select_columns)) + combined_lineage = SqlExpressionTreeLineage.merge_iterable(tuple(x.expr.lineage for x in select_columns)) return len(combined_lineage.string_exprs) > 0 diff --git a/metricflow/sql/optimizer/tag_required_column_aliases.py b/metricflow/sql/optimizer/tag_required_column_aliases.py index 81a09f009..32dfacd32 100644 --- a/metricflow/sql/optimizer/tag_required_column_aliases.py +++ b/metricflow/sql/optimizer/tag_required_column_aliases.py @@ -102,7 +102,7 @@ def _search_for_expressions( if select_node.where: all_expr_search_results.append(select_node.where.lineage) - return SqlExpressionTreeLineage.combine(all_expr_search_results) + return SqlExpressionTreeLineage.merge_iterable(all_expr_search_results) @override def visit_cte_node(self, node: SqlCteNode) -> None: diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 62d8e9687..9c0a0674d 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple +from typing import Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple import more_itertools from dbt_semantic_interfaces.enum_extension import assert_values_exhausted @@ -15,6 +15,7 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from metricflow_semantics.collection_helpers.merger import Mergeable from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet @@ -100,7 +101,7 @@ def matches(self, other: SqlExpressionNode) -> bool: @dataclass(frozen=True) -class SqlExpressionTreeLineage: +class SqlExpressionTreeLineage(Mergeable): """Captures the lineage of an expression node - contains itself and all ancestor nodes.""" string_exprs: Tuple[SqlStringExpression, ...] = () @@ -109,17 +110,18 @@ class SqlExpressionTreeLineage: column_alias_reference_exprs: Tuple[SqlColumnAliasReferenceExpression, ...] = () other_exprs: Tuple[SqlExpressionNode, ...] = () - @staticmethod - def combine(lineages: Sequence[SqlExpressionTreeLineage]) -> SqlExpressionTreeLineage: + @classmethod + @override + def merge_iterable(cls, items: Iterable[SqlExpressionTreeLineage]) -> SqlExpressionTreeLineage: """Combine multiple lineages into one lineage, without de-duping.""" return SqlExpressionTreeLineage( - string_exprs=tuple(more_itertools.flatten(tuple(x.string_exprs for x in lineages))), - function_exprs=tuple(more_itertools.flatten(tuple(x.function_exprs for x in lineages))), - column_reference_exprs=tuple(more_itertools.flatten(tuple(x.column_reference_exprs for x in lineages))), + string_exprs=tuple(more_itertools.flatten(tuple(x.string_exprs for x in items))), + function_exprs=tuple(more_itertools.flatten(tuple(x.function_exprs for x in items))), + column_reference_exprs=tuple(more_itertools.flatten(tuple(x.column_reference_exprs for x in items))), column_alias_reference_exprs=tuple( - more_itertools.flatten(tuple(x.column_alias_reference_exprs for x in lineages)) + more_itertools.flatten(tuple(x.column_alias_reference_exprs for x in items)) ), - other_exprs=tuple(more_itertools.flatten(tuple(x.other_exprs for x in lineages))), + other_exprs=tuple(more_itertools.flatten(tuple(x.other_exprs for x in items))), ) @property @@ -138,6 +140,21 @@ def contains_ambiguous_exprs(self) -> bool: # noqa: D102 def contains_aggregate_exprs(self) -> bool: # noqa: D102 return any(x.is_aggregate_function for x in self.function_exprs) + @override + def merge(self, other: SqlExpressionTreeLineage) -> SqlExpressionTreeLineage: + return SqlExpressionTreeLineage( + string_exprs=self.string_exprs + other.string_exprs, + function_exprs=self.function_exprs + other.function_exprs, + column_reference_exprs=self.column_reference_exprs + other.column_reference_exprs, + column_alias_reference_exprs=self.column_alias_reference_exprs + other.column_alias_reference_exprs, + other_exprs=self.other_exprs + other.other_exprs, + ) + + @classmethod + @override + def empty_instance(cls) -> SqlExpressionTreeLineage: + return SqlExpressionTreeLineage() + class SqlColumnReplacements: """When re-writing column references in expressions, this stores the mapping.""" @@ -604,7 +621,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -803,7 +820,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),) ) @@ -923,7 +940,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),) ) @@ -1084,7 +1101,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),) ) @@ -1194,7 +1211,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -1241,7 +1258,9 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine([self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))]) + return SqlExpressionTreeLineage.merge_iterable( + [self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))] + ) def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 if not isinstance(other, SqlIsNullExpression): @@ -1304,7 +1323,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -1351,7 +1370,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -1402,7 +1421,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -1461,7 +1480,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -1526,7 +1545,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) ) @@ -1591,7 +1610,7 @@ def rewrite( # noqa: D102 @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 - return SqlExpressionTreeLineage.combine( + return SqlExpressionTreeLineage.merge_iterable( tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) )