Skip to content

Commit

Permalink
Add direct tests for predicate pushdown optimizer (#1303)
Browse files Browse the repository at this point in the history
This addresses the testing gap around the cases covered by the
logic internal to the PredicatePushdownOptimizer.

It's rather difficult to hand-construct dataflow plans, so these tests
rely on snapshots of DataflowPlan text output for some carefully
selected queries which are known to cover the code paths in question.
  • Loading branch information
tlento authored Jun 26, 2024
1 parent d61835b commit 69feadc
Show file tree
Hide file tree
Showing 19 changed files with 5,457 additions and 3 deletions.
5 changes: 5 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ def __init__(self, sink_nodes: Sequence[DataflowPlanNode], plan_id: Optional[Dag
def sink_node(self) -> DataflowPlanNode: # noqa: D102
return self._sink_nodes[0]

@property
def node_count(self) -> int:
    """Return the number of nodes in this DataflowPlan.

    The count is the length of the flattened node sequence produced by
    ``__all_nodes_in_subgraph`` when it is started from ``self.sink_node``.
    """
    nodes_reachable_from_sink = DataflowPlan.__all_nodes_in_subgraph(self.sink_node)
    return len(nodes_reachable_from_sink)

@staticmethod
def __all_nodes_in_subgraph(node: DataflowPlanNode) -> Sequence[DataflowPlanNode]:
"""Node accessor for retrieving a flattened sequence of all nodes in the subgraph upstream of the input node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo
handling this scenario at this time.
"""
self._log_visit_node_type(node)
# TODO: move this "remove where filters" logic into PredicatePushdownState
updated_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@
from datetime import datetime

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.specs.spec_classes import WhereFilterSpec
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.spec_classes import (
WhereFilterSpec,
)
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import assert_plan_snapshot_text_equal

from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownBranchStateTracker
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import (
PredicatePushdownBranchStateTracker,
PredicatePushdownOptimizer,
)
from metricflow.plan_conversion.node_processor import PredicateInputType, PredicatePushdownState


Expand Down Expand Up @@ -148,3 +160,238 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown

# We expect to propagate back to the initial entry since we only ever want to apply a filter once within a branch
assert branch_state_tracker.last_pushdown_state == both_applied_state


def _check_optimization(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    dataflow_plan_builder: DataflowPlanBuilder,
    query_spec: MetricFlowQuerySpec,
    expected_additional_constraint_nodes_in_optimized: int,
) -> None:
    """Build a plan, run the predicate pushdown optimizer on it, and verify the outcome.

    Both the plan built from ``query_spec`` and the plan returned by the optimizer are
    checked against snapshots of their structure text. Afterwards the node counts of the
    two plans are compared.

    Args:
        request: pytest fixture request, forwarded to the snapshot assertion.
        mf_test_configuration: test configuration, forwarded to the snapshot assertion.
        dataflow_plan_builder: builder used to create the initial plan; its node data set
            resolver is also handed to the optimizer.
        query_spec: the query to build the dataflow plan for.
        expected_additional_constraint_nodes_in_optimized: how many more nodes the optimized
            plan must contain than the initial plan.

    Raises:
        AssertionError: if the node count difference does not match the expected value
            (a failed snapshot comparison surfaces from the snapshot helper itself).
    """
    unoptimized_plan = dataflow_plan_builder.build_plan(query_spec=query_spec)
    pushdown_optimizer = PredicatePushdownOptimizer(
        node_data_set_resolver=dataflow_plan_builder._node_data_set_resolver,
    )
    optimized_plan = pushdown_optimizer.optimize(dataflow_plan=unoptimized_plan)

    # Snapshot the unoptimized plan first and the optimized plan second.
    for snapshot_target in (unoptimized_plan, optimized_plan):
        assert_plan_snapshot_text_equal(
            request=request,
            mf_test_configuration=mf_test_configuration,
            plan=snapshot_target,
            plan_snapshot_text=snapshot_target.structure_text(),
        )

    # Compute the delta once so the comparison and the failure message agree.
    added_node_count = optimized_plan.node_count - unoptimized_plan.node_count
    assert added_node_count == expected_additional_constraint_nodes_in_optimized, (
        f"Did not get the expected number ({expected_additional_constraint_nodes_in_optimized}) of additional "
        f"constraint nodes in the optimized plan, found {added_node_count} added "
        "nodes. Check snapshot output for details."
    )


def test_simple_join_categorical_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown of a simple categorical predicate through a single join.

    One additional constraint node is expected in the optimized plan.
    """
    instant_booking_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("bookings",),
        group_by_names=("listing__country_latest",),
        where_constraint=instant_booking_filter,
    )
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=1,
    )


def test_simple_join_metric_time_pushdown_with_two_targets(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown of a simple metric_time predicate through a single join.

    This covers the scenario where the dimension source is also a metric time node, but the
    metric_time filter must NOT be applied to it because it is a _current style dimension
    table at its core.

    Note the optimizer will not push the predicate down until metric_time pushdown is
    supported, so no additional constraint nodes are expected for now.
    """
    metric_time_filter = PydanticWhereFilter(
        where_sql_template="{{ TimeDimension('metric_time') }} = '2024-01-01'",
    )
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("bookings",),
        group_by_names=("listing__country_latest",),
        where_constraint=metric_time_filter,
    )
    # TODO: Add support for time dimension pushdown (the expected count is 0 until then)
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=0,
    )


def test_conversion_metric_predicate_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown optimizer behavior for a simple predicate on a conversion metric.

    For the time being the pushdown should NOT move past the conversion metric node.
    """
    referrer_filter = PydanticWhereFilter(
        where_sql_template="{{ Dimension('visit__referrer_id') }} = '123456'",
    )
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("visit_buy_conversion_rate_7days",),
        group_by_names=("metric_time", "user__home_state_latest"),
        where_constraint=referrer_filter,
    )
    # TODO: Remove superfluous where constraint nodes (the expected count is 1 until then)
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=1,
    )


def test_cumulative_metric_predicate_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown optimizer behavior for a query against a cumulative metric.

    Categorical dimension predicates should currently be pushed down, while metric_time
    predicates should not be: time filter pushdown for cumulative metrics requires filter
    expansion so that the full set of inputs for the initial cumulative window is captured.

    TODO: Add metric time filters
    """
    instant_booking_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("every_two_days_bookers",),
        group_by_names=("listing__country_latest", "metric_time"),
        where_constraint=instant_booking_filter,
    )
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=1,
    )


@pytest.mark.skip("plan output has non-deterministic ordering")
def test_aggregate_output_join_metric_predicate_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown optimizer behavior when a metric does an aggregate output metric join.

    Filters are expected to stay where they are in this case, since they sit outside of a
    full outer join. The test is currently skipped (see the skip marker for the reason).
    """
    lux_listing_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('listing__is_lux_latest') }}")
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("views_times_booking_value",),
        where_constraint=lux_listing_filter,
    )
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=0,
    )


def test_offset_metric_predicate_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown optimizer behavior for a query against a derived offset metric.

    As with cumulative metrics, categorical dimension predicates may currently be pushed
    down, but metric_time predicates should not be, because the union of the filter window
    and the offset span has to be captured.

    TODO: Add metric time filters
    """
    instant_booking_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("bookings_growth_2_weeks",),
        group_by_names=("listing__country_latest", "metric_time"),
        where_constraint=instant_booking_filter,
    )
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=1,
    )


def test_fill_nulls_time_spine_metric_predicate_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown optimizer behavior for a metric with a time spine and fill_nulls_with enabled.

    Until time dimension pushdown is supported, only the categorical dimension entry is
    expected to be pushed down here.

    TODO: Add metric time filters
    """
    instant_booking_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("bookings_growth_2_weeks_fill_nulls_with_0",),
        group_by_names=("listing__country_latest", "metric_time"),
        where_constraint=instant_booking_filter,
    )
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=1,
    )


def test_fill_nulls_time_spine_metric_with_post_agg_join_predicate_pushdown(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    query_parser: MetricFlowQueryParser,
    dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
    """Checks pushdown optimizer behavior for a time spine metric with fill_nulls_with and a post-agg join.

    Querying a metric like this with a group by on all filter specs produces a
    post-aggregation outer join against the time spine. That join should preclude predicate
    pushdown for query-time filters at that stage, while still allowing pushdown within the
    JoinToTimeSpine operation. Pushdown therefore happens exactly as in the case without the
    post-agg join, and the constraint added outside of the JoinToTimeSpine operation must
    still be applied.

    Until time dimension pushdown is supported, only the categorical dimension entry is
    expected to be pushed down here.

    TODO: Add metric time filters
    """
    instant_booking_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("bookings_growth_2_weeks_fill_nulls_with_0",),
        group_by_names=("listing__country_latest", "booking__is_instant", "metric_time"),
        where_constraint=instant_booking_filter,
    )
    _check_optimization(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_plan_builder=dataflow_plan_builder,
        query_spec=parse_result.query_spec,
        expected_additional_constraint_nodes_in_optimized=1,
    )
Loading

0 comments on commit 69feadc

Please sign in to comment.