Verify column sets match in optimizer pushdown evaluation
The PredicatePushdownOptimizer did not have access to the node dataset
resolver needed to compare the linkable specs referenced in the where
filter against the linkable specs available from the DataflowPlanNode
targeted by predicate pushdown.

This change makes the node dataset resolver available in the optimizer.
It uses the resolver from the DataflowPlanBuilder in order to take
advantage of the cached resolutions from the build process. The optimizer
then uses the resolver to evaluate the column matches in the same manner
as the original build-time pushdown evaluation.
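
In outline, the new evaluation works roughly as follows. This is a
simplified sketch, not the exact implementation: the imports and attribute
access mirror the diff below, but the helper name is hypothetical.

from typing import Sequence, Tuple

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow_semantics.specs.spec_classes import WhereFilterSpec


def filters_matching_node_columns(
    node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
    node: DataflowPlanNode,
    where_filter_specs: Sequence[WhereFilterSpec],
) -> Tuple[WhereFilterSpec, ...]:
    # Hypothetical helper: a filter is only eligible for pushdown to `node` when
    # every linkable spec it references appears in the node's output data set,
    # as reported by the node data set resolver.
    source_node_linkable_specs = node_data_set_resolver.get_output_data_set(
        node
    ).instance_set.spec_set.linkable_specs
    return tuple(
        filter_spec
        for filter_spec in where_filter_specs
        if all(spec in source_node_linkable_specs for spec in filter_spec.linkable_specs)
    )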

This change was tested by running one of the predicate pushdown rendering
tests with the --log-cli-level=DEBUG configuration set and observing that
the debug output includes the same "Filter specs to add:" entry as in the
preceding commit.
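
For reference, an invocation along these lines should surface that log
entry (the -k keyword expression is illustrative only; --log-cli-level is
a standard pytest option):

pytest -k predicate_pushdown --log-cli-level=DEBUG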
tlento committed Jun 13, 2024
1 parent 5a904ed commit fda9d9b
Showing 3 changed files with 67 additions and 39 deletions.
2 changes: 1 addition & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -227,7 +227,7 @@ def _build_plan(
return self._optimize_plan(plan, optimizations)

def _optimize_plan(self, plan: DataflowPlan, optimizations: Sequence[DataflowPlanOptimization]) -> DataflowPlan:
optimizer_factory = DataflowPlanOptimizerFactory()
optimizer_factory = DataflowPlanOptimizerFactory(self._node_data_set_resolver)
for optimizer in optimizer_factory.get_optimizers(optimizations):
logger.info(f"Applying {optimizer.__class__.__name__}")
try:
10 changes: 9 additions & 1 deletion metricflow/dataflow/optimizer/dataflow_optimizer_factory.py
@@ -5,6 +5,7 @@

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer
from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer
@@ -24,14 +25,21 @@ class DataflowPlanOptimizerFactory:
processing between the DataflowPlanBuilder and the optimizer instances requiring that functionality.
"""

def __init__(self, node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver) -> None:
"""Initializer.
This collects all of the initialization requirements for the optimizers it manages.
"""
self._node_data_set_resolver = node_data_set_resolver

def get_optimizers(self, optimizations: Sequence[DataflowPlanOptimization]) -> Sequence[DataflowPlanOptimizer]:
"""Initializes and returns a sequence of optimizers matching the input optimization requests."""
optimizers: List[DataflowPlanOptimizer] = []
for optimization in optimizations:
if optimization is DataflowPlanOptimization.SOURCE_SCAN:
optimizers.append(SourceScanOptimizer())
elif optimization is DataflowPlanOptimization.PREDICATE_PUSHDOWN:
optimizers.append(PredicatePushdownOptimizer())
optimizers.append(PredicatePushdownOptimizer(self._node_data_set_resolver))
else:
assert_values_exhausted(optimization)

94 changes: 57 additions & 37 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
@@ -3,14 +3,15 @@
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List, Optional, Sequence, Tuple
from typing import Iterator, List, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.references import SemanticModelReference
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId
from metricflow_semantics.specs.spec_classes import WhereFilterSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.dataflow_plan import (
DataflowPlan,
DataflowPlanNode,
@@ -104,13 +105,14 @@ class PredicatePushdownOptimizer(
constraint node if it is appropriate to do so.
"""

def __init__(self) -> None:
def __init__(self, node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver) -> None:
"""Initializer.
Initializes predicate pushdown state with all optimizer-managed pushdown types enabled, but nothing to
push down, since time range constraints and where filter specs will be discovered during traversal.
"""
self._log_level = logging.DEBUG
self._node_data_set_resolver = node_data_set_resolver
self._predicate_pushdown_tracker = PredicatePushdownBranchStateTracker(
initial_state=PredicatePushdownState(
time_range_constraint=None,
@@ -167,18 +169,53 @@ def _default_handler(
def _models_for_spec(self, spec: WhereFilterSpec) -> Sequence[SemanticModelReference]:
"""Return the distinct semantic models that source the elements referenced in the given where spec.
TODO: make this a property of the spec
TODO: Include special handling for entities, which can be sourced from multiple semantic models
"""
return tuple(set(element.semantic_model_origin for element in spec.linkable_elements))

# Source nodes - potential pushdown targets.

def visit_metric_time_dimension_transform_node( # noqa: D102
def visit_metric_time_dimension_transform_node(
self, node: MetricTimeDimensionTransformNode
) -> OptimizeBranchResult:
# TODO: Update docstring and logic to apply filter where needed. For now we simply add a logging
# hook with a superset of eligible filters we may consider applying here
"""Handles predicate pushdown operations against the MetricTimeDimensionTransformNode.
This node is the one where the metric_time column is constructed. As such, any where filter
targeting a measure input will be pushed to this node, and no further. In theory we could push
down directly to the ReadSqlSourceNode, but that requires some juggling on metric_time references
so we stop here for any matched filters. This shouldn't cause any meaningful problems, as the
SqlQueryPlanOptimizer processes typically collapse this and the underlying ReadSqlSourceNode into
a single subquery anyway.
As this is the base metric_time node, all time-based filter predicate pushdown needs to be managed
here.
"""
self._log_visit_node_type(node)
# TODO: Update to handle time range constraints
return self._push_down_where_filters(node)

def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult:
"""Handles predicate pushdown to ReadSqlSourceNode.
This node is currently the root node in the dataflow DAG. In most cases, predicate pushdown will
stop with the MetricTimeDimensionTransformNode, but if there is ever a scenario where we do a
metric-free query with a one-sided outer join, an inner join, or a predicate filter set that can
be pushed past outer join boundaries (via some kind of semantic analysis or other semantic guarantee)
we want to make sure any applicable filters can be bound as closely to the input source as possible.
"""
self._log_visit_node_type(node)
return self._push_down_where_filters(node)

def _push_down_where_filters(
self, node: Union[MetricTimeDimensionTransformNode, ReadSqlSourceNode]
) -> OptimizeBranchResult:
"""Helper method for pushing where filters down to base source nodes.
This only accepts the two supported source node types - the ReadSqlSourceNode and the
MetricTimeDimensionTransformNode. In theory we could push down to ReadSqlSourceNode in every scenario, but
in practice this gets tricky given that filters on metric_time are expected, and metric_time is not a column
available on the original sql source node.
"""
current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
node_semantic_models = node.as_plan().source_semantic_models
if len(node_semantic_models) != 1 or not current_pushdown_state.has_where_filters_to_push_down:
@@ -187,13 +224,21 @@ def visit_metric_time_dimension_transform_node(  # noqa: D102
source_semantic_model, *_ = node_semantic_models
filters_to_apply: List[WhereFilterSpec] = []
filters_left_over: List[WhereFilterSpec] = []
for spec in current_pushdown_state.where_filter_specs:
spec_semantic_models = self._models_for_spec(spec)
if len(spec_semantic_models) == 1 and spec_semantic_models[0] == source_semantic_model:
# TODO: check columns against the spec elements
filters_to_apply.append(spec)
source_node_linkable_specs = self._node_data_set_resolver.get_output_data_set(
node
).instance_set.spec_set.linkable_specs

for filter_spec in current_pushdown_state.where_filter_specs:
filter_spec_semantic_models = self._models_for_spec(filter_spec)
all_linkable_specs_match = all(spec in source_node_linkable_specs for spec in filter_spec.linkable_specs)
semantic_models_match = (
len(filter_spec_semantic_models) == 1 and filter_spec_semantic_models[0] == source_semantic_model
)
if all_linkable_specs_match and semantic_models_match:
filters_to_apply.append(filter_spec)
else:
filters_left_over.append(spec)
filters_left_over.append(filter_spec)

logger.log(level=self._log_level, msg=f"Filter specs to add:\n{filters_to_apply}")
# TODO: wrap node with a WhereConstraintNode and propagate filters applied back up the branch for removal
updated_pushdown_state = PredicatePushdownState(
@@ -203,31 +248,6 @@ def visit_metric_time_dimension_transform_node(  # noqa: D102
)
return self._default_handler(node=node, pushdown_state=updated_pushdown_state)

def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult: # noqa: D102
# TODO: Update docstring and logic to apply filters where needed.
# The commented out logic is a placeholder for what we'll use, ignore this for now. We only need to push
# down to this node type if we are doing a dimension-only query on a non-metric dataset, at least for now
#
# self._log_visit_node_type(node)
# current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
# node_semantic_models = node.as_plan().source_semantic_models
# if len(node_semantic_models) != 1 or current_pushdown_state.has_where_filters_to_push_down:
# return self._default_handler(node)

# source_semantic_model, *_ = node_semantic_models
# filters_to_apply: List[WhereFilterSpec] = []
# for spec in current_pushdown_state.where_filter_specs:
# spec_semantic_models = self._models_for_spec(spec)
# if len(spec_semantic_models) == 1 and spec_semantic_models[0] == source_semantic_model:
# # TODO: check columns against the spec elements
# filters_to_apply.append(spec)
# logger.log(level=self._log_level, msg=f"Filter specs to add:\n{filters_to_apply}")
# # TODO: wrap node with a WhereConstraintNode and propagate filters applied back up the branch for removal
# return self._default_handler(node)

self._log_visit_node_type(node)
return self._default_handler(node)

# Constraint nodes - predicate sources for pushdown.

def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult:
