From 8d4d7fd9268d9e8c779ae63fa2da0554854fc89b Mon Sep 17 00:00:00 2001 From: tlento Date: Thu, 13 Jun 2024 10:45:55 -0700 Subject: [PATCH] Update API for requesting dataflow plan optimization In order to fully support predicate pushdown via the DataflowPlanOptimizer framework we need two things: 1. Support for optimization in distinct values queries 2. The ability to share components between the DataflowPlanBuilder and the PredicatePushdownOptimizer This update addresses both of these concerns by doing a small restructure of the DataflowPlanBuilder interface for accepting optimizers. Instead of accepting a sequence of optimizer instances, the build_plan method will now accept a sequence of optimization enumerations. Those will then be converted to instances via the factory class added in this change. From there the update to the distinct values plan method signature was a trivial addition. Note - snapshot updates should be limited to ID numbers due to the added call to the DataflowPlanNodeOutputDataSetResolver in the distinct values plan. --- .../dataflow/builder/dataflow_plan_builder.py | 30 ++++++++++----- .../optimizer/dataflow_optimizer_factory.py | 38 +++++++++++++++++++ metricflow/engine/metricflow_engine.py | 6 +-- .../query_rendering/compare_rendered_query.py | 8 ++-- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- ...ry_with_metric_filter__plan0_optimized.sql | 10 ++--- 11 files changed, 100 insertions(+), 52 deletions(-) create mode 100644 metricflow/dataflow/optimizer/dataflow_optimizer_factory.py diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 56af8825de..25be62f8ac 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -82,7 +82,10 @@ from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode -from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer +from metricflow.dataflow.optimizer.dataflow_optimizer_factory import ( + DataflowPlanOptimization, + DataflowPlanOptimizerFactory, +) from metricflow.dataset.dataset_classes import DataSet from metricflow.plan_conversion.node_processor import ( PredicateInputType, @@ -143,7 +146,7 @@ def build_plan( query_spec: MetricFlowQuerySpec, output_sql_table: Optional[SqlTable] = None, output_selection_specs: Optional[InstanceSpecSet] = None, - optimizers: Sequence[DataflowPlanOptimizer] = (), + optimizations: Sequence[DataflowPlanOptimization] = (), ) -> DataflowPlan: """Generate a plan for reading the results of a query with the given spec into a data_table or table.""" # Workaround for a Pycharm type inspection issue with decorators. @@ -152,7 +155,7 @@ def build_plan( query_spec=query_spec, output_sql_table=output_sql_table, output_selection_specs=output_selection_specs, - optimizers=optimizers, + optimizations=optimizations, ) def _build_query_output_node( @@ -208,7 +211,7 @@ def _build_plan( query_spec: MetricFlowQuerySpec, output_sql_table: Optional[SqlTable], output_selection_specs: Optional[InstanceSpecSet], - optimizers: Sequence[DataflowPlanOptimizer], + optimizations: Sequence[DataflowPlanOptimization], ) -> DataflowPlan: metrics_output_node = self._build_query_output_node(query_spec=query_spec) @@ -222,7 +225,11 @@ def _build_plan( plan_id = DagId.from_id_prefix(StaticIdPrefix.DATAFLOW_PLAN_PREFIX) plan = DataflowPlan(sink_nodes=[sink_node], plan_id=plan_id) - for optimizer in optimizers: + return self._optimize_plan(plan, optimizations) + + def _optimize_plan(self, plan: DataflowPlan, optimizations: Sequence[DataflowPlanOptimization]) -> DataflowPlan: + optimizer_factory = DataflowPlanOptimizerFactory() + for optimizer in optimizer_factory.get_optimizers(optimizations): logger.info(f"Applying {optimizer.__class__.__name__}") try: plan = optimizer.optimize(plan) @@ -733,17 +740,21 @@ def _build_metrics_output_node( return CombineAggregatedOutputsNode(parent_nodes=output_nodes) - def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan: + def build_plan_for_distinct_values( + self, query_spec: MetricFlowQuerySpec, optimizations: Sequence[DataflowPlanOptimization] = () + ) -> DataflowPlan: """Generate a plan that would get the distinct values of a linkable instance. e.g. distinct listing__country_latest for bookings by listing__country_latest """ # Workaround for a Pycharm type inspection issue with decorators. # noinspection PyArgumentList - return self._build_plan_for_distinct_values(query_spec) + return self._build_plan_for_distinct_values(query_spec, optimizations=optimizations) @log_runtime() - def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan: + def _build_plan_for_distinct_values( + self, query_spec: MetricFlowQuerySpec, optimizations: Sequence[DataflowPlanOptimization] + ) -> DataflowPlan: assert not query_spec.metric_specs, "Can't build distinct values plan with metrics." query_level_filter_specs: Sequence[WhereFilterSpec] = () if query_spec.filter_intersection is not None and len(query_spec.filter_intersection.where_filters) > 0: @@ -792,7 +803,8 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit ) - return DataflowPlan(sink_nodes=[sink_node]) + plan = DataflowPlan(sink_nodes=[sink_node]) + return self._optimize_plan(plan, optimizations) @staticmethod def build_sink_node( diff --git a/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py b/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py new file mode 100644 index 0000000000..57282e4137 --- /dev/null +++ b/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from enum import Enum +from typing import List, Sequence + +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted + +from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer +from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer +from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer + + +class DataflowPlanOptimization(Enum): + """Enumeration of optimization types available for execution.""" + + PREDICATE_PUSHDOWN = "predicate_pushdown" + SOURCE_SCAN = "source_scan" + + +class DataflowPlanOptimizerFactory: + """Factory class for initializing an enumerated set of optimizers. + + This allows us to centralize initialization and, most importantly, share class instances with cached high cost + processing between the DataflowPlanBuilder and the optimizer instances requiring that functionality. + """ + + def get_optimizers(self, optimizations: Sequence[DataflowPlanOptimization]) -> Sequence[DataflowPlanOptimizer]: + """Initializes and returns a sequence of optimizers matching the input optimization requests.""" + optimizers: List[DataflowPlanOptimizer] = [] + for optimization in optimizations: + if optimization is DataflowPlanOptimization.SOURCE_SCAN: + optimizers.append(SourceScanOptimizer()) + elif optimization is DataflowPlanOptimization.PREDICATE_PUSHDOWN: + optimizers.append(PredicatePushdownOptimizer()) + else: + assert_values_exhausted(optimization) + + return tuple(optimizers) diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index dd35f97c9a..93a15657d0 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -39,9 +39,7 @@ from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.source_node import SourceNodeBuilder from metricflow.dataflow.dataflow_plan import DataflowPlan -from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import ( - SourceScanOptimizer, -) +from metricflow.dataflow.optimizer.dataflow_optimizer_factory import DataflowPlanOptimization from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter from metricflow.dataset.dataset_classes import DataSet from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet @@ -502,7 +500,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me dataflow_plan = self._dataflow_plan_builder.build_plan( query_spec=query_spec, output_selection_specs=output_selection_specs, - optimizers=(SourceScanOptimizer(),), + optimizations=(DataflowPlanOptimization.SOURCE_SCAN,), ) else: dataflow_plan = self._dataflow_plan_builder.build_plan_for_distinct_values(query_spec=query_spec) diff --git a/tests_metricflow/query_rendering/compare_rendered_query.py b/tests_metricflow/query_rendering/compare_rendered_query.py index 83d2feda1b..d99a019bf7 100644 --- a/tests_metricflow/query_rendering/compare_rendered_query.py +++ b/tests_metricflow/query_rendering/compare_rendered_query.py @@ -8,7 +8,7 @@ from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder -from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer +from metricflow.dataflow.optimizer.dataflow_optimizer_factory import DataflowPlanOptimization from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel @@ -52,11 +52,11 @@ def render_and_check( ) # Run dataflow -> sql conversion with all optimizers + optimizations = (DataflowPlanOptimization.PREDICATE_PUSHDOWN,) if is_distinct_values_plan: - # TODO: Make optimization available for distinct values plans - optimized_plan = base_plan + optimized_plan = dataflow_plan_builder.build_plan_for_distinct_values(query_spec, optimizations=optimizations) else: - optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizers=(PredicatePushdownOptimizer(),)) + optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizations=optimizations) conversion_result = dataflow_to_sql_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, dataflow_plan_node=optimized_plan.sink_node, diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing