Skip to content

Commit

Permalink
Rename to NodeToColumnAliasMapping.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 8, 2024
1 parent 37d8c5b commit 69e5feb
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 129 deletions.
44 changes: 26 additions & 18 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import logging
from typing import FrozenSet, Mapping

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet
from metricflow.sql.optimizer.tag_required_column_aliases import SqlTagRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.optimizer.tag_required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
Expand All @@ -29,7 +29,7 @@ class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):

def __init__(
self,
required_alias_mapping: Mapping[SqlQueryPlanNode, FrozenSet[str]],
required_alias_mapping: NodeToColumnAliasMapping,
) -> None:
"""Constructor.
Expand All @@ -42,7 +42,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
required_column_aliases = self._required_alias_mapping.get(node)
required_column_aliases = self._required_alias_mapping.get_aliases(node)
if required_column_aliases is None:
logger.error(
f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version "
Expand Down Expand Up @@ -101,23 +101,31 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
"""Removes unnecessary columns in the SELECT clauses."""
"""Removes unnecessary columns in the SELECT statements."""

def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102
# Can't prune columns without knowing the structure of the query.
if not node.as_select_node:
# ALl columns in the nearest SELECT node need to be kept as otherwise, the meaning of the query changes.
required_select_columns = node.nearest_select_columns()

# Can't prune without knowing the structure of the query.
if required_select_columns is None:
logger.debug(
LazyFormat(
"The columns required at this node can't be determined, so skipping column pruning",
node=node.structure_text(),
required_select_columns=required_select_columns,
)
)
return node

# Figure out which columns in which nodes are required.
tagged_column_alias_set = TaggedColumnAliasSet()
tagged_column_alias_set.tag_all_aliases_in_node(node.as_select_node)
tag_required_column_alias_visitor = SqlTagRequiredColumnAliasesVisitor(
tagged_column_alias_set=tagged_column_alias_set,
map_required_column_aliases_visitor = SqlMapRequiredColumnAliasesVisitor(
start_node=node,
required_column_aliases_in_start_node=frozenset(
[select_column.column_alias for select_column in required_select_columns]
),
)
node.accept(tag_required_column_alias_visitor)
node.accept(map_required_column_aliases_visitor)

# Re-write the query, pruning columns in the SELECT that are not needed.
pruning_visitor = SqlColumnPrunerVisitor(
required_alias_mapping=tagged_column_alias_set.get_mapping(),
)
# Re-write the query, removing unnecessary columns in the SELECT statements.
pruning_visitor = SqlColumnPrunerVisitor(map_required_column_aliases_visitor.required_column_alias_mapping)
return node.accept(pruning_visitor)
83 changes: 10 additions & 73 deletions metricflow/sql/optimizer/tag_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,32 @@

import logging
from collections import defaultdict
from typing import Dict, FrozenSet, Iterable, Mapping, Set

from typing_extensions import override
from typing import Dict, FrozenSet, Iterable, Set

from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)


class TaggedColumnAliasSet:
"""Keep track of column aliases in SELECT statements that have been tagged.
The main use case for this class is to keep track of column aliases / columns that are required so that unnecessary
columns can be pruned.
For example, in this query:
SELECT source_0.col_0 AS col_0
FROM (
SELECT
example_table.col_0
example_table.col_1
FROM example_table
) source_0
class NodeToColumnAliasMapping:
"""Mutable class for mapping a SQL node to an arbitrary set of column aliases for that node.
this class can be used to tag `example_table.col_0` but not tag `example_table.col_1` since it's not needed for the
query to run correctly.
* Alternatively, this can be described as mapping a location in the SQL query plan to a set of column aliases.
* See `SqlMapRequiredColumnAliasesVisitor` for the main use case for this class.
* This is a thin wrapper over a dict to aid readability.
"""

def __init__(self) -> None: # noqa: D107
self._node_to_tagged_aliases: Dict[SqlQueryPlanNode, Set[str]] = defaultdict(set)

def get_tagged_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]:
"""Return the given tagged column aliases associated with the given SQL node."""
def get_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]:
"""Return the column aliases added for the given SQL node."""
return frozenset(self._node_to_tagged_aliases[node])

def tag_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102
def add_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102
return self._node_to_tagged_aliases[node].add(column_alias)

def tag_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102
def add_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102
self._node_to_tagged_aliases[node].update(column_aliases)

def tag_all_aliases_in_node(self, node: SqlQueryPlanNode) -> None:
"""Convenience method that tags all column aliases in the given SQL node, where appropriate."""
node.accept(_TagAllColumnAliasesInNodeVisitor(self))

def get_mapping(self) -> Mapping[SqlQueryPlanNode, FrozenSet[str]]:
"""Return mapping view that associates a given SQL node with the tagged column aliases in that node."""
return {node: frozenset(tagged_aliases) for node, tagged_aliases in self._node_to_tagged_aliases.items()}


class _TagAllColumnAliasesInNodeVisitor(SqlQueryPlanNodeVisitor[None]):
"""Visitor to help implement `TaggedColumnAliasSet.tag_all_aliases_in_node`."""

def __init__(self, required_column_alias_collector: TaggedColumnAliasSet) -> None:
self._required_column_alias_collector = required_column_alias_collector

@override
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
for select_column in node.select_columns:
self._required_column_alias_collector.tag_alias(
node=node,
column_alias=select_column.column_alias,
)

@override
def visit_table_node(self, node: SqlTableNode) -> None:
"""Columns in a SQL table are not represented."""
pass

@override
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
"""Columns in an arbitrary SQL query are not represented."""
pass

@override
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)
Loading

0 comments on commit 69e5feb

Please sign in to comment.