diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index a61c8db8f..3f4bdb102 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -1,13 +1,13 @@ from __future__ import annotations import logging -from typing import FrozenSet, Mapping +from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from typing_extensions import override from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer -from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet -from metricflow.sql.optimizer.tag_required_column_aliases import SqlTagRequiredColumnAliasesVisitor +from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping +from metricflow.sql.optimizer.tag_required_column_aliases import SqlMapRequiredColumnAliasesVisitor from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, @@ -29,7 +29,7 @@ class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]): def __init__( self, - required_alias_mapping: Mapping[SqlQueryPlanNode, FrozenSet[str]], + required_alias_mapping: NodeToColumnAliasMapping, ) -> None: """Constructor. @@ -42,7 +42,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP # Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't # need them. However, keep columns that are in group bys because that changes the meaning of the query. # Similarly, if this node is a distinct select node, keep all columns as it may return a different result set. - required_column_aliases = self._required_alias_mapping.get(node) + required_column_aliases = self._required_alias_mapping.get_aliases(node) if required_column_aliases is None: logger.error( f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version " @@ -101,23 +101,31 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer): - """Removes unnecessary columns in the SELECT clauses.""" + """Removes unnecessary columns in the SELECT statements.""" def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102 - # Can't prune columns without knowing the structure of the query. - if not node.as_select_node: + # ALl columns in the nearest SELECT node need to be kept as otherwise, the meaning of the query changes. + required_select_columns = node.nearest_select_columns() + + # Can't prune without knowing the structure of the query. + if required_select_columns is None: + logger.debug( + LazyFormat( + "The columns required at this node can't be determined, so skipping column pruning", + node=node.structure_text(), + required_select_columns=required_select_columns, + ) + ) return node - # Figure out which columns in which nodes are required. - tagged_column_alias_set = TaggedColumnAliasSet() - tagged_column_alias_set.tag_all_aliases_in_node(node.as_select_node) - tag_required_column_alias_visitor = SqlTagRequiredColumnAliasesVisitor( - tagged_column_alias_set=tagged_column_alias_set, + map_required_column_aliases_visitor = SqlMapRequiredColumnAliasesVisitor( + start_node=node, + required_column_aliases_in_start_node=frozenset( + [select_column.column_alias for select_column in required_select_columns] + ), ) - node.accept(tag_required_column_alias_visitor) + node.accept(map_required_column_aliases_visitor) - # Re-write the query, pruning columns in the SELECT that are not needed. - pruning_visitor = SqlColumnPrunerVisitor( - required_alias_mapping=tagged_column_alias_set.get_mapping(), - ) + # Re-write the query, removing unnecessary columns in the SELECT statements. + pruning_visitor = SqlColumnPrunerVisitor(map_required_column_aliases_visitor.required_column_alias_mapping) return node.accept(pruning_visitor) diff --git a/metricflow/sql/optimizer/tag_column_aliases.py b/metricflow/sql/optimizer/tag_column_aliases.py index 77b7ee41e..c54987eea 100644 --- a/metricflow/sql/optimizer/tag_column_aliases.py +++ b/metricflow/sql/optimizer/tag_column_aliases.py @@ -2,95 +2,32 @@ import logging from collections import defaultdict -from typing import Dict, FrozenSet, Iterable, Mapping, Set - -from typing_extensions import override +from typing import Dict, FrozenSet, Iterable, Set from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, - SqlCteNode, SqlQueryPlanNode, - SqlQueryPlanNodeVisitor, - SqlSelectQueryFromClauseNode, - SqlSelectStatementNode, - SqlTableNode, ) logger = logging.getLogger(__name__) -class TaggedColumnAliasSet: - """Keep track of column aliases in SELECT statements that have been tagged. - - The main use case for this class is to keep track of column aliases / columns that are required so that unnecessary - columns can be pruned. - - For example, in this query: - - SELECT source_0.col_0 AS col_0 - FROM ( - SELECT - example_table.col_0 - example_table.col_1 - FROM example_table - ) source_0 +class NodeToColumnAliasMapping: + """Mutable class for mapping a SQL node to an arbitrary set of column aliases for that node. - this class can be used to tag `example_table.col_0` but not tag `example_table.col_1` since it's not needed for the - query to run correctly. + * Alternatively, this can be described as mapping a location in the SQL query plan to a set of column aliases. + * See `SqlMapRequiredColumnAliasesVisitor` for the main use case for this class. + * This is a thin wrapper over a dict to aid readability. """ def __init__(self) -> None: # noqa: D107 self._node_to_tagged_aliases: Dict[SqlQueryPlanNode, Set[str]] = defaultdict(set) - def get_tagged_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]: - """Return the given tagged column aliases associated with the given SQL node.""" + def get_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]: + """Return the column aliases added for the given SQL node.""" return frozenset(self._node_to_tagged_aliases[node]) - def tag_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102 + def add_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102 return self._node_to_tagged_aliases[node].add(column_alias) - def tag_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102 + def add_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102 self._node_to_tagged_aliases[node].update(column_aliases) - - def tag_all_aliases_in_node(self, node: SqlQueryPlanNode) -> None: - """Convenience method that tags all column aliases in the given SQL node, where appropriate.""" - node.accept(_TagAllColumnAliasesInNodeVisitor(self)) - - def get_mapping(self) -> Mapping[SqlQueryPlanNode, FrozenSet[str]]: - """Return mapping view that associates a given SQL node with the tagged column aliases in that node.""" - return {node: frozenset(tagged_aliases) for node, tagged_aliases in self._node_to_tagged_aliases.items()} - - -class _TagAllColumnAliasesInNodeVisitor(SqlQueryPlanNodeVisitor[None]): - """Visitor to help implement `TaggedColumnAliasSet.tag_all_aliases_in_node`.""" - - def __init__(self, required_column_alias_collector: TaggedColumnAliasSet) -> None: - self._required_column_alias_collector = required_column_alias_collector - - @override - def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: - for select_column in node.select_columns: - self._required_column_alias_collector.tag_alias( - node=node, - column_alias=select_column.column_alias, - ) - - @override - def visit_table_node(self, node: SqlTableNode) -> None: - """Columns in a SQL table are not represented.""" - pass - - @override - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None: - """Columns in an arbitrary SQL query are not represented.""" - pass - - @override - def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None: - for parent_node in node.parent_nodes: - parent_node.accept(self) - - @override - def visit_cte_node(self, node: SqlCteNode) -> None: - for parent_node in node.parent_nodes: - parent_node.accept(self) diff --git a/metricflow/sql/optimizer/tag_required_column_aliases.py b/metricflow/sql/optimizer/tag_required_column_aliases.py index 6b34fb545..d2557cf9d 100644 --- a/metricflow/sql/optimizer/tag_required_column_aliases.py +++ b/metricflow/sql/optimizer/tag_required_column_aliases.py @@ -2,12 +2,12 @@ import logging from collections import defaultdict -from typing import Dict, List, Set, Tuple +from typing import Dict, FrozenSet, List, Set, Tuple from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from typing_extensions import override -from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet +from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping from metricflow.sql.sql_exprs import SqlExpressionTreeLineage from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, @@ -23,21 +23,29 @@ logger = logging.getLogger(__name__) -class SqlTagRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]): - """To aid column pruning, traverse the SQL-query representation DAG and tag all column aliases that are required. +class SqlMapRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]): + """To aid column pruning, traverse the SQL-query representation DAG and map the SELECT columns needed at each node. - For example, for the query: + For example, the query: + -- SELECT node_id="select_0" SELECT source_0.col_0 AS col_0_renamed FROM ( + -- SELECT node_id="select_1 SELECT example_table.col_0 example_table.col_1 FROM example_table_0 ) source_0 - The top-level SQL node would have the column alias `col_0_renamed` tagged, and the SQL node associated with - `source_0` would have `col_0` tagged. Once tagged, the information can be used to prune the columns in the SELECT: + would generate the mapping: + + { + "select_0": {"col_0"}, + "select_1": {"col_0"), + } + + The mapping can be later used to rewrite the query to: SELECT source_0.col_0 AS col_0_renamed FROM ( @@ -47,15 +55,26 @@ class SqlTagRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]): ) source_0 """ - def __init__(self, tagged_column_alias_set: TaggedColumnAliasSet) -> None: + def __init__(self, start_node: SqlQueryPlanNode, required_column_aliases_in_start_node: FrozenSet[str]) -> None: """Initializer. Args: - tagged_column_alias_set: Stores the set of columns that are tagged. This will be updated as the visitor - traverses the SQL-query representation DAG. + start_node: The node where the traversal by this visitor will start. + required_column_aliases_in_start_node: The column aliases at the `start_node` that are required. """ - self._column_alias_tagger = tagged_column_alias_set - self._cte_alias_to_cte_node: Dict[str, SqlCteNode] = {} + # Stores the mapping of the SQL node to the required column aliases. This will be updated as the visitor + # traverses the SQL-query representation DAG. + self._current_required_column_alias_mapping = NodeToColumnAliasMapping() + self._current_required_column_alias_mapping.add_aliases(start_node, required_column_aliases_in_start_node) + + # Helps lookup the CTE node associated with a given CTE alias. A member variable is needed as any node in the + # SQL DAG can reference a CTE. + start_node_as_select_node = start_node.as_select_node + self._cte_alias_to_cte_node: Dict[str, SqlCteNode] = ( + {cte_source.cte_alias: cte_source for cte_source in start_node_as_select_node.cte_sources} + if start_node_as_select_node is not None + else {} + ) def _search_for_expressions( self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...] @@ -90,8 +109,8 @@ def visit_cte_node(self, node: SqlCteNode) -> None: select_statement = node.select_statement # Copy the tagged aliases from the CTE to the SELECT since when visiting a SELECT, the CTE node (not the SELECT # in the CTE) was tagged with the required aliases. - required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node) - self._column_alias_tagger.tag_aliases(select_statement, required_column_aliases_in_this_node) + required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node) + self._current_required_column_alias_mapping.add_aliases(select_statement, required_column_aliases_in_this_node) # Visit parent nodes. select_statement.accept(self) @@ -105,16 +124,13 @@ def _tag_potential_cte_node(self, table_name: str, column_aliases: Set[str]) -> """A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs.""" cte_node = self._cte_alias_to_cte_node.get(table_name) if cte_node is not None: - self._column_alias_tagger.tag_aliases(cte_node, column_aliases) + self._current_required_column_alias_mapping.add_aliases(cte_node, column_aliases) # Propagate the required aliases to parents, which could be other CTEs. cte_node.accept(self) - def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # noqa: D102 - # Based on column aliases that are tagged in this SELECT statement, tag corresponding column aliases in - # parent nodes. - self._cte_alias_to_cte_node.update({cte_source.cte_alias: cte_source for cte_source in node.cte_sources}) - - initial_required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node) + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: + """Based on required column aliases for this SELECT, figure out required column aliases in parents.""" + initial_required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node) # If this SELECT statement uses DISTINCT, all columns are required as removing them would change the meaning of # the query. @@ -137,7 +153,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # ) ) # Since additional select columns could have been selected due to DISTINCT or GROUP BY, re-tag. - self._column_alias_tagger.tag_aliases(node, updated_required_column_aliases_in_this_node) + self._current_required_column_alias_mapping.add_aliases(node, updated_required_column_aliases_in_this_node) required_select_columns_in_this_node = tuple( select_column @@ -145,12 +161,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # if select_column.column_alias in updated_required_column_aliases_in_this_node ) - # TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any - # SqlExpressionNode. - if len(required_select_columns_in_this_node) == 0: raise RuntimeError( - "No columns are required in this node - this indicates a bug in this collector or in the inputs." + "No columns are required in this node - this indicates a bug in this visitor or in the inputs." ) # Based on the expressions in this select statement, figure out what column aliases are needed in the sources of @@ -161,26 +174,30 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # # impossible to know what columns can be pruned from the parent sources. Tag all columns in parents as required. if any([string_expr.used_columns is None for string_expr in exprs_used_in_this_node.string_exprs]): for parent_node in node.parent_nodes: - self._column_alias_tagger.tag_all_aliases_in_node(parent_node) + nearest_select_columns = parent_node.nearest_select_columns() + for select_column in nearest_select_columns or (): + self._current_required_column_alias_mapping.add_alias(parent_node, select_column.column_alias) self._visit_parents(node) return # Create a mapping from the source alias to the column aliases needed from the corresponding source. - source_alias_to_required_column_alias: Dict[str, Set[str]] = defaultdict(set) + source_alias_to_required_column_aliases: Dict[str, Set[str]] = defaultdict(set) for column_reference_expr in exprs_used_in_this_node.column_reference_exprs: column_reference = column_reference_expr.col_ref - source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name) + source_alias_to_required_column_aliases[column_reference.table_alias].add(column_reference.column_name) logger.debug( LazyFormat( "Collected required column names from sources", - source_alias_to_required_column_alias=source_alias_to_required_column_alias, + source_alias_to_required_column_aliases=source_alias_to_required_column_aliases, ) ) # Appropriately tag the columns required in the parent nodes. - if node.from_source_alias in source_alias_to_required_column_alias: - aliases_required_in_parent = source_alias_to_required_column_alias[node.from_source_alias] - self._column_alias_tagger.tag_aliases(node=node.from_source, column_aliases=aliases_required_in_parent) + if node.from_source_alias in source_alias_to_required_column_aliases: + aliases_required_in_parent = source_alias_to_required_column_aliases[node.from_source_alias] + self._current_required_column_alias_mapping.add_aliases( + node=node.from_source, column_aliases=aliases_required_in_parent + ) from_source_as_sql_table_node = node.from_source.as_sql_table_node if from_source_as_sql_table_node is not None: self._tag_potential_cte_node( @@ -189,9 +206,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # ) for join_desc in node.join_descs: - if join_desc.right_source_alias in source_alias_to_required_column_alias: - aliases_required_in_parent = source_alias_to_required_column_alias[join_desc.right_source_alias] - self._column_alias_tagger.tag_aliases( + if join_desc.right_source_alias in source_alias_to_required_column_aliases: + aliases_required_in_parent = source_alias_to_required_column_aliases[join_desc.right_source_alias] + self._current_required_column_alias_mapping.add_aliases( node=join_desc.right_source, column_aliases=aliases_required_in_parent ) right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node @@ -207,7 +224,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # if string_expr.used_columns: for column_alias in string_expr.used_columns: for parent_node in node.parent_nodes: - self._column_alias_tagger.tag_alias(parent_node, column_alias) + self._current_required_column_alias_mapping.add_alias(parent_node, column_alias) # Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say # it's required from all parents. @@ -216,7 +233,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs: column_alias = unqualified_column_reference_expr.column_alias for parent_node in node.parent_nodes: - self._column_alias_tagger.tag_alias(parent_node, column_alias) + self._current_required_column_alias_mapping.add_alias(parent_node, column_alias) # Visit recursively. self._visit_parents(node) @@ -232,3 +249,8 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> No def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None: # noqa: D102 return self._visit_parents(node) + + @property + def required_column_alias_mapping(self) -> NodeToColumnAliasMapping: + """Return the column aliases required at each node as determined after traversal.""" + return self._current_required_column_alias_mapping diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index e6fb088a5..1757d2de0 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -51,6 +51,18 @@ def as_sql_table_node(self) -> Optional[SqlTableNode]: """If possible, return this as SQL table node.""" raise NotImplementedError + @abstractmethod + def nearest_select_columns(self) -> Optional[Sequence[SqlSelectColumn]]: + """Return the SELECT columns that are in this node or the closest `SqlSelectStatementNode` of its ancestors. + + * For a SELECT statement node, this is just the columns in the node. + * For a node that has a SELECT statement node as its only parent (e.g. CREATE TABLE ... AS SELECT ...), this + would be the SELECT columns in the parent. + * If not known (e.g. an arbitrary SQL statement as a string), return None. + * This is used to figure out which columns are needed at a leaf node of the DAG for column pruning. + """ + raise NotImplementedError + class SqlQueryPlanNodeVisitor(Generic[VisitorOutputT], ABC): """An object that can be used to visit the nodes of an SQL query. @@ -206,6 +218,10 @@ def as_sql_table_node(self) -> Optional[SqlTableNode]: def description(self) -> str: return self._description + @override + def nearest_select_columns(self) -> Sequence[SqlSelectColumn]: + return self.select_columns + @dataclass(frozen=True, eq=False) class SqlTableNode(SqlQueryPlanNode): @@ -244,6 +260,10 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 def as_sql_table_node(self) -> Optional[SqlTableNode]: return self + @override + def nearest_select_columns(self) -> Optional[Sequence[SqlSelectColumn]]: + return None + @dataclass(frozen=True, eq=False) class SqlSelectQueryFromClauseNode(SqlQueryPlanNode): @@ -282,6 +302,10 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 def as_sql_table_node(self) -> Optional[SqlTableNode]: return None + @override + def nearest_select_columns(self) -> Optional[Sequence[SqlSelectColumn]]: + return None + @dataclass(frozen=True, eq=False) class SqlCreateTableAsNode(SqlQueryPlanNode): @@ -332,6 +356,10 @@ def parent_node(self) -> SqlQueryPlanNode: # noqa: D102 def id_prefix(cls) -> IdPrefix: return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX + @override + def nearest_select_columns(self) -> Optional[Sequence[SqlSelectColumn]]: + return self.parent_node.nearest_select_columns() + class SqlQueryPlan(MetricFlowDag[SqlQueryPlanNode]): """Model for an SQL Query as a DAG.""" @@ -361,6 +389,10 @@ class SqlCteNode(SqlQueryPlanNode): select_statement: SqlQueryPlanNode cte_alias: str + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + @staticmethod def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: # noqa: D102 return SqlCteNode( @@ -399,3 +431,11 @@ def description(self) -> str: @override def id_prefix(cls) -> IdPrefix: return StaticIdPrefix.SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX + + @property + def parent_node(self) -> SqlQueryPlanNode: # noqa: D102 + return self.parent_nodes[0] + + @override + def nearest_select_columns(self) -> Optional[Sequence[SqlSelectColumn]]: + return self.parent_node.nearest_select_columns()