diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 56af8825de..25be62f8ac 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -82,7 +82,10 @@ from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode -from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer +from metricflow.dataflow.optimizer.dataflow_optimizer_factory import ( + DataflowPlanOptimization, + DataflowPlanOptimizerFactory, +) from metricflow.dataset.dataset_classes import DataSet from metricflow.plan_conversion.node_processor import ( PredicateInputType, @@ -143,7 +146,7 @@ def build_plan( query_spec: MetricFlowQuerySpec, output_sql_table: Optional[SqlTable] = None, output_selection_specs: Optional[InstanceSpecSet] = None, - optimizers: Sequence[DataflowPlanOptimizer] = (), + optimizations: Sequence[DataflowPlanOptimization] = (), ) -> DataflowPlan: """Generate a plan for reading the results of a query with the given spec into a data_table or table.""" # Workaround for a Pycharm type inspection issue with decorators. @@ -152,7 +155,7 @@ def build_plan( query_spec=query_spec, output_sql_table=output_sql_table, output_selection_specs=output_selection_specs, - optimizers=optimizers, + optimizations=optimizations, ) def _build_query_output_node( @@ -208,7 +211,7 @@ def _build_plan( query_spec: MetricFlowQuerySpec, output_sql_table: Optional[SqlTable], output_selection_specs: Optional[InstanceSpecSet], - optimizers: Sequence[DataflowPlanOptimizer], + optimizations: Sequence[DataflowPlanOptimization], ) -> DataflowPlan: metrics_output_node = self._build_query_output_node(query_spec=query_spec) @@ -222,7 +225,11 @@ def _build_plan( plan_id = DagId.from_id_prefix(StaticIdPrefix.DATAFLOW_PLAN_PREFIX) plan = DataflowPlan(sink_nodes=[sink_node], plan_id=plan_id) - for optimizer in optimizers: + return self._optimize_plan(plan, optimizations) + + def _optimize_plan(self, plan: DataflowPlan, optimizations: Sequence[DataflowPlanOptimization]) -> DataflowPlan: + optimizer_factory = DataflowPlanOptimizerFactory() + for optimizer in optimizer_factory.get_optimizers(optimizations): logger.info(f"Applying {optimizer.__class__.__name__}") try: plan = optimizer.optimize(plan) @@ -733,17 +740,21 @@ def _build_metrics_output_node( return CombineAggregatedOutputsNode(parent_nodes=output_nodes) - def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan: + def build_plan_for_distinct_values( + self, query_spec: MetricFlowQuerySpec, optimizations: Sequence[DataflowPlanOptimization] = () + ) -> DataflowPlan: """Generate a plan that would get the distinct values of a linkable instance. e.g. distinct listing__country_latest for bookings by listing__country_latest """ # Workaround for a Pycharm type inspection issue with decorators. # noinspection PyArgumentList - return self._build_plan_for_distinct_values(query_spec) + return self._build_plan_for_distinct_values(query_spec, optimizations=optimizations) @log_runtime() - def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan: + def _build_plan_for_distinct_values( + self, query_spec: MetricFlowQuerySpec, optimizations: Sequence[DataflowPlanOptimization] + ) -> DataflowPlan: assert not query_spec.metric_specs, "Can't build distinct values plan with metrics." query_level_filter_specs: Sequence[WhereFilterSpec] = () if query_spec.filter_intersection is not None and len(query_spec.filter_intersection.where_filters) > 0: @@ -792,7 +803,8 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit ) - return DataflowPlan(sink_nodes=[sink_node]) + plan = DataflowPlan(sink_nodes=[sink_node]) + return self._optimize_plan(plan, optimizations) @staticmethod def build_sink_node( diff --git a/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py b/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py new file mode 100644 index 0000000000..57282e4137 --- /dev/null +++ b/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from enum import Enum +from typing import List, Sequence + +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted + +from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer +from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer +from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer + + +class DataflowPlanOptimization(Enum): + """Enumeration of optimization types available for execution.""" + + PREDICATE_PUSHDOWN = "predicate_pushdown" + SOURCE_SCAN = "source_scan" + + +class DataflowPlanOptimizerFactory: + """Factory class for initializing an enumerated set of optimizers. + + This allows us to centralize initialization and, most importantly, share class instances with cached high cost + processing between the DataflowPlanBuilder and the optimizer instances requiring that functionality. + """ + + def get_optimizers(self, optimizations: Sequence[DataflowPlanOptimization]) -> Sequence[DataflowPlanOptimizer]: + """Initializes and returns a sequence of optimizers matching the input optimization requests.""" + optimizers: List[DataflowPlanOptimizer] = [] + for optimization in optimizations: + if optimization is DataflowPlanOptimization.SOURCE_SCAN: + optimizers.append(SourceScanOptimizer()) + elif optimization is DataflowPlanOptimization.PREDICATE_PUSHDOWN: + optimizers.append(PredicatePushdownOptimizer()) + else: + assert_values_exhausted(optimization) + + return tuple(optimizers) diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index dd35f97c9a..93a15657d0 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -39,9 +39,7 @@ from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.source_node import SourceNodeBuilder from metricflow.dataflow.dataflow_plan import DataflowPlan -from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import ( - SourceScanOptimizer, -) +from metricflow.dataflow.optimizer.dataflow_optimizer_factory import DataflowPlanOptimization from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter from metricflow.dataset.dataset_classes import DataSet from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet @@ -502,7 +500,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me dataflow_plan = self._dataflow_plan_builder.build_plan( query_spec=query_spec, output_selection_specs=output_selection_specs, - optimizers=(SourceScanOptimizer(),), + optimizations=(DataflowPlanOptimization.SOURCE_SCAN,), ) else: dataflow_plan = self._dataflow_plan_builder.build_plan_for_distinct_values(query_spec=query_spec) diff --git a/tests_metricflow/query_rendering/compare_rendered_query.py b/tests_metricflow/query_rendering/compare_rendered_query.py index 83d2feda1b..d99a019bf7 100644 --- a/tests_metricflow/query_rendering/compare_rendered_query.py +++ b/tests_metricflow/query_rendering/compare_rendered_query.py @@ -8,7 +8,7 @@ from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder -from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer +from metricflow.dataflow.optimizer.dataflow_optimizer_factory import DataflowPlanOptimization from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel @@ -52,11 +52,11 @@ def render_and_check( ) # Run dataflow -> sql conversion with all optimizers + optimizations = (DataflowPlanOptimization.PREDICATE_PUSHDOWN,) if is_distinct_values_plan: - # TODO: Make optimization available for distinct values plans - optimized_plan = base_plan + optimized_plan = dataflow_plan_builder.build_plan_for_distinct_values(query_spec, optimizations=optimizations) else: - optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizers=(PredicatePushdownOptimizer(),)) + optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizations=optimizations) conversion_result = dataflow_to_sql_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, dataflow_plan_node=optimized_plan.sink_node, diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/BigQuery/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Databricks/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Postgres/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Redshift/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Snowflake/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql index fc77fa7686..cde9c960c5 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/Trino/test_distinct_values_query_with_metric_filter__plan0_optimized.sql @@ -6,7 +6,7 @@ FROM ( -- Join Standard Outputs SELECT lux_listing_mapping_src_28000.listing_id AS listing - , subq_19.listing__bookings AS listing__bookings + , subq_23.listing__bookings AS listing__bookings FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000 FULL OUTER JOIN ( -- Aggregate Measures @@ -23,13 +23,13 @@ FROM ( listing_id AS listing , 1 AS bookings FROM ***************************.fct_bookings bookings_source_src_28000 - ) subq_16 + ) subq_20 GROUP BY listing - ) subq_19 + ) subq_23 ON - lux_listing_mapping_src_28000.listing_id = subq_19.listing -) subq_20 + lux_listing_mapping_src_28000.listing_id = subq_23.listing +) subq_24 WHERE listing__bookings > 2 GROUP BY listing