Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update API for requesting dataflow plan optimization #1278

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataflow.optimizer.dataflow_optimizer_factory import (
DataflowPlanOptimization,
DataflowPlanOptimizerFactory,
)
from metricflow.dataset.dataset_classes import DataSet
from metricflow.plan_conversion.node_processor import (
PredicateInputType,
Expand Down Expand Up @@ -143,7 +146,7 @@ def build_plan(
query_spec: MetricFlowQuerySpec,
output_sql_table: Optional[SqlTable] = None,
output_selection_specs: Optional[InstanceSpecSet] = None,
optimizers: Sequence[DataflowPlanOptimizer] = (),
optimizations: Sequence[DataflowPlanOptimization] = (),
) -> DataflowPlan:
"""Generate a plan for reading the results of a query with the given spec into a data_table or table."""
# Workaround for a Pycharm type inspection issue with decorators.
Expand All @@ -152,7 +155,7 @@ def build_plan(
query_spec=query_spec,
output_sql_table=output_sql_table,
output_selection_specs=output_selection_specs,
optimizers=optimizers,
optimizations=optimizations,
)

def _build_query_output_node(
Expand Down Expand Up @@ -208,7 +211,7 @@ def _build_plan(
query_spec: MetricFlowQuerySpec,
output_sql_table: Optional[SqlTable],
output_selection_specs: Optional[InstanceSpecSet],
optimizers: Sequence[DataflowPlanOptimizer],
optimizations: Sequence[DataflowPlanOptimization],
) -> DataflowPlan:
metrics_output_node = self._build_query_output_node(query_spec=query_spec)

Expand All @@ -222,7 +225,11 @@ def _build_plan(

plan_id = DagId.from_id_prefix(StaticIdPrefix.DATAFLOW_PLAN_PREFIX)
plan = DataflowPlan(sink_nodes=[sink_node], plan_id=plan_id)
for optimizer in optimizers:
return self._optimize_plan(plan, optimizations)

def _optimize_plan(self, plan: DataflowPlan, optimizations: Sequence[DataflowPlanOptimization]) -> DataflowPlan:
optimizer_factory = DataflowPlanOptimizerFactory()
for optimizer in optimizer_factory.get_optimizers(optimizations):
logger.info(f"Applying {optimizer.__class__.__name__}")
try:
plan = optimizer.optimize(plan)
Expand Down Expand Up @@ -733,17 +740,21 @@ def _build_metrics_output_node(

return CombineAggregatedOutputsNode(parent_nodes=output_nodes)

def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan:
def build_plan_for_distinct_values(
self, query_spec: MetricFlowQuerySpec, optimizations: Sequence[DataflowPlanOptimization] = ()
) -> DataflowPlan:
"""Generate a plan that would get the distinct values of a linkable instance.

e.g. distinct listing__country_latest for bookings by listing__country_latest
"""
# Workaround for a Pycharm type inspection issue with decorators.
# noinspection PyArgumentList
return self._build_plan_for_distinct_values(query_spec)
return self._build_plan_for_distinct_values(query_spec, optimizations=optimizations)

@log_runtime()
def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan:
def _build_plan_for_distinct_values(
self, query_spec: MetricFlowQuerySpec, optimizations: Sequence[DataflowPlanOptimization]
) -> DataflowPlan:
assert not query_spec.metric_specs, "Can't build distinct values plan with metrics."
query_level_filter_specs: Sequence[WhereFilterSpec] = ()
if query_spec.filter_intersection is not None and len(query_spec.filter_intersection.where_filters) > 0:
Expand Down Expand Up @@ -792,7 +803,8 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da
parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit
)

return DataflowPlan(sink_nodes=[sink_node])
plan = DataflowPlan(sink_nodes=[sink_node])
return self._optimize_plan(plan, optimizations)

@staticmethod
def build_sink_node(
Expand Down
38 changes: 38 additions & 0 deletions metricflow/dataflow/optimizer/dataflow_optimizer_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from enum import Enum
from typing import List, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted

from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer
from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer


class DataflowPlanOptimization(Enum):
    """Names the dataflow plan optimizations a caller can request.

    Callers pass these members instead of constructing optimizer objects themselves.
    """

    # Requests the optimizer that handles predicate pushdown.
    PREDICATE_PUSHDOWN = "predicate_pushdown"
    # Requests the optimizer that handles source scans.
    SOURCE_SCAN = "source_scan"


class DataflowPlanOptimizerFactory:
    """Single place where requested optimizations are turned into optimizer instances.

    The stated motivation for centralizing construction here is that class instances holding cached, expensive
    processing can be shared between the DataflowPlanBuilder and whichever optimizers need them.
    """

    def get_optimizers(self, optimizations: Sequence[DataflowPlanOptimization]) -> Sequence[DataflowPlanOptimizer]:
        """Build one optimizer per requested optimization and return them as a tuple, in request order."""
        built: List[DataflowPlanOptimizer] = []
        for requested in optimizations:
            if requested is DataflowPlanOptimization.SOURCE_SCAN:
                built.append(SourceScanOptimizer())
                continue
            if requested is DataflowPlanOptimization.PREDICATE_PUSHDOWN:
                built.append(PredicatePushdownOptimizer())
                continue
            # Every known member is handled above; this flags an enum value that was added without a branch here.
            assert_values_exhausted(requested)

        return tuple(built)
6 changes: 2 additions & 4 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.source_node import SourceNodeBuilder
from metricflow.dataflow.dataflow_plan import DataflowPlan
from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import (
SourceScanOptimizer,
)
from metricflow.dataflow.optimizer.dataflow_optimizer_factory import DataflowPlanOptimization
from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter
from metricflow.dataset.dataset_classes import DataSet
from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
Expand Down Expand Up @@ -502,7 +500,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
dataflow_plan = self._dataflow_plan_builder.build_plan(
query_spec=query_spec,
output_selection_specs=output_selection_specs,
optimizers=(SourceScanOptimizer(),),
optimizations=(DataflowPlanOptimization.SOURCE_SCAN,),
)
else:
dataflow_plan = self._dataflow_plan_builder.build_plan_for_distinct_values(query_spec=query_spec)
Expand Down
8 changes: 4 additions & 4 deletions tests_metricflow/query_rendering/compare_rendered_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownOptimizer
from metricflow.dataflow.optimizer.dataflow_optimizer_factory import DataflowPlanOptimization
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
Expand Down Expand Up @@ -52,11 +52,11 @@ def render_and_check(
)

# Run dataflow -> sql conversion with all optimizers
optimizations = (DataflowPlanOptimization.PREDICATE_PUSHDOWN,)
if is_distinct_values_plan:
# TODO: Make optimization available for distinct values plans
optimized_plan = base_plan
optimized_plan = dataflow_plan_builder.build_plan_for_distinct_values(query_spec, optimizations=optimizations)
else:
optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizers=(PredicatePushdownOptimizer(),))
optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizations=optimizations)
conversion_result = dataflow_to_sql_converter.convert_to_sql_query_plan(
sql_engine_type=sql_client.sql_engine_type,
dataflow_plan_node=optimized_plan.sink_node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM (
-- Join Standard Outputs
SELECT
lux_listing_mapping_src_28000.listing_id AS listing
, subq_19.listing__bookings AS listing__bookings
, subq_23.listing__bookings AS listing__bookings
FROM ***************************.dim_lux_listing_id_mapping lux_listing_mapping_src_28000
FULL OUTER JOIN (
-- Aggregate Measures
Expand All @@ -23,13 +23,13 @@ FROM (
listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_28000
) subq_16
) subq_20
GROUP BY
listing
) subq_19
) subq_23
ON
lux_listing_mapping_src_28000.listing_id = subq_19.listing
) subq_20
lux_listing_mapping_src_28000.listing_id = subq_23.listing
) subq_24
WHERE listing__bookings > 2
GROUP BY
listing
Loading