Skip to content

Commit

Permalink
Cache common operations in dataflow plan builder.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Oct 10, 2024
1 parent 19af13d commit 56f2d0a
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 41 deletions.
78 changes: 78 additions & 0 deletions metricflow/dataflow/builder/builder_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from metricflow_semantics.collection_helpers.lru_cache import LruCache
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory

from metricflow.dataflow.builder.measure_spec_properties import MeasureSpecProperties
from metricflow.dataflow.builder.source_node_recipe import SourceNodeRecipe
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.plan_conversion.node_processor import PredicatePushdownState


@dataclass(frozen=True)
class FindSourceNodeRecipeParameterSet:
"""Parameters for `DataflowPlanBuilder._find_source_node_recipe()`."""

linkable_spec_set: LinkableSpecSet
predicate_pushdown_state: PredicatePushdownState
measure_spec_properties: Optional[MeasureSpecProperties]


@dataclass(frozen=True)
class FindSourceNodeRecipeResult:
"""Result for `DataflowPlanBuilder._find_source_node_recipe()`."""

source_node_recipe: Optional[SourceNodeRecipe]


@dataclass(frozen=True)
class BuildAnyMetricOutputNodeParameterSet:
"""Parameters for `DataflowPlanBuilder._build_any_metric_output_node()`."""

metric_spec: MetricSpec
queried_linkable_specs: LinkableSpecSet
filter_spec_factory: WhereSpecFactory
predicate_pushdown_state: PredicatePushdownState
for_group_by_source_node: bool


class DataflowPlanBuilderCache:
"""Cache for internal methods in `DataflowPlanBuilder`."""

def __init__( # noqa: D107
self, find_source_node_recipe_cache_size: int = 1000, build_any_metric_output_node_cache_size: int = 1000
) -> None:
self._find_source_node_recipe_cache = LruCache[FindSourceNodeRecipeParameterSet, FindSourceNodeRecipeResult](
find_source_node_recipe_cache_size
)
self._build_any_metric_output_node_cache = LruCache[BuildAnyMetricOutputNodeParameterSet, DataflowPlanNode](
build_any_metric_output_node_cache_size
)

assert find_source_node_recipe_cache_size > 0
assert build_any_metric_output_node_cache_size > 0

def get_find_source_node_recipe_result( # noqa: D102
self, parameter_set: FindSourceNodeRecipeParameterSet
) -> Optional[FindSourceNodeRecipeResult]:
return self._find_source_node_recipe_cache.get(parameter_set)

def set_find_source_node_recipe_result( # noqa: D102
self, parameter_set: FindSourceNodeRecipeParameterSet, source_node_recipe: FindSourceNodeRecipeResult
) -> None:
self._find_source_node_recipe_cache.set(parameter_set, source_node_recipe)

def get_build_any_metric_output_node_result( # noqa: D102
self, parameter_set: BuildAnyMetricOutputNodeParameterSet
) -> Optional[DataflowPlanNode]:
return self._build_any_metric_output_node_cache.get(parameter_set)

def set_build_any_metric_output_node_result( # noqa: D102
self, parameter_set: BuildAnyMetricOutputNodeParameterSet, dataflow_plan_node: DataflowPlanNode
) -> None:
self._build_any_metric_output_node_cache.set(parameter_set, dataflow_plan_node)
131 changes: 90 additions & 41 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.dataflow.builder.builder_cache import (
BuildAnyMetricOutputNodeParameterSet,
DataflowPlanBuilderCache,
FindSourceNodeRecipeParameterSet,
FindSourceNodeRecipeResult,
)
from metricflow.dataflow.builder.measure_spec_properties import MeasureSpecProperties
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.node_evaluator import (
Expand Down Expand Up @@ -111,6 +117,7 @@ def __init__( # noqa: D107
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
column_association_resolver: ColumnAssociationResolver,
source_node_builder: SourceNodeBuilder,
dataflow_plan_builder_cache: Optional[DataflowPlanBuilderCache] = None,
) -> None:
self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup
self._metric_lookup = semantic_manifest_lookup.metric_lookup
Expand All @@ -120,6 +127,7 @@ def __init__( # noqa: D107
self._node_data_set_resolver = node_output_resolver
self._source_node_builder = source_node_builder
self._time_period_adjuster = DateutilTimePeriodAdjuster()
self._cache = dataflow_plan_builder_cache or DataflowPlanBuilderCache()

def build_plan(
self,
Expand Down Expand Up @@ -254,15 +262,19 @@ def _build_aggregated_conversion_node(
filter_specs=base_measure_spec.filter_specs,
)
base_measure_recipe = self._find_source_node_recipe(
measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
predicate_pushdown_state=time_range_only_pushdown_state,
linkable_spec_set=base_required_linkable_specs,
FindSourceNodeRecipeParameterSet(
measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
predicate_pushdown_state=time_range_only_pushdown_state,
linkable_spec_set=base_required_linkable_specs,
)
)
logger.debug(LazyFormat(lambda: f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}"))
conversion_measure_recipe = self._find_source_node_recipe(
measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]),
predicate_pushdown_state=disabled_pushdown_state,
linkable_spec_set=LinkableSpecSet(),
FindSourceNodeRecipeParameterSet(
measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]),
predicate_pushdown_state=disabled_pushdown_state,
linkable_spec_set=LinkableSpecSet(),
)
)
logger.debug(
LazyFormat(lambda: f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}")
Expand Down Expand Up @@ -596,18 +608,21 @@ def _build_derived_metric_output_node(

parent_nodes.append(
self._build_any_metric_output_node(
metric_spec=MetricSpec(
element_name=metric_input_spec.element_name,
filter_specs=tuple(filter_specs),
alias=metric_input_spec.alias,
offset_window=metric_input_spec.offset_window,
offset_to_grain=metric_input_spec.offset_to_grain,
),
queried_linkable_specs=(
queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs
),
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=metric_pushdown_state,
BuildAnyMetricOutputNodeParameterSet(
metric_spec=MetricSpec(
element_name=metric_input_spec.element_name,
filter_specs=tuple(filter_specs),
alias=metric_input_spec.alias,
offset_window=metric_input_spec.offset_window,
offset_to_grain=metric_input_spec.offset_to_grain,
),
queried_linkable_specs=(
queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs
),
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=metric_pushdown_state,
for_group_by_source_node=False,
)
)
)

Expand Down Expand Up @@ -652,15 +667,26 @@ def _build_derived_metric_output_node(
)
return output_node

def _build_any_metric_output_node(
self,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
filter_spec_factory: WhereSpecFactory,
predicate_pushdown_state: PredicatePushdownState,
for_group_by_source_node: bool = False,
def _build_any_metric_output_node(self, parameter_set: BuildAnyMetricOutputNodeParameterSet) -> DataflowPlanNode:
"""Builds a node to compute a metric of any type."""
result = self._cache.get_build_any_metric_output_node_result(parameter_set)
if result is not None:
return result

result = self._build_any_metric_output_node_non_cached(parameter_set)
self._cache.set_build_any_metric_output_node_result(parameter_set, result)
return result

def _build_any_metric_output_node_non_cached(
self, parameter_set: BuildAnyMetricOutputNodeParameterSet
) -> DataflowPlanNode:
"""Builds a node to compute a metric of any type."""
metric_spec = parameter_set.metric_spec
queried_linkable_specs = parameter_set.queried_linkable_specs
filter_spec_factory = parameter_set.filter_spec_factory
predicate_pushdown_state = parameter_set.predicate_pushdown_state
for_group_by_source_node = parameter_set.for_group_by_source_node

metric = self._metric_lookup.get_metric(metric_spec.reference)

if metric.type is MetricType.SIMPLE:
Expand Down Expand Up @@ -726,11 +752,13 @@ def _build_metrics_output_node(

output_nodes.append(
self._build_any_metric_output_node(
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=predicate_pushdown_state,
for_group_by_source_node=for_group_by_source_node,
BuildAnyMetricOutputNodeParameterSet(
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=predicate_pushdown_state,
for_group_by_source_node=for_group_by_source_node,
)
)
)

Expand Down Expand Up @@ -777,7 +805,11 @@ def _build_plan_for_distinct_values(
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=tuple(query_level_filter_specs)
)
dataflow_recipe = self._find_source_node_recipe(
linkable_spec_set=required_linkable_specs, predicate_pushdown_state=predicate_pushdown_state
FindSourceNodeRecipeParameterSet(
linkable_spec_set=required_linkable_specs,
predicate_pushdown_state=predicate_pushdown_state,
measure_spec_properties=None,
)
)
if not dataflow_recipe:
raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}")
Expand Down Expand Up @@ -944,19 +976,34 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
if measure_agg_time_dimension != agg_time_dimension:
raise ValueError(f"measure_specs {measure_specs} do not have the same agg_time_dimension.")
return MeasureSpecProperties(
measure_specs=measure_specs,
measure_specs=tuple(measure_specs),
semantic_model_name=semantic_model_name,
agg_time_dimension=agg_time_dimension,
non_additive_dimension_spec=non_additive_dimension_spec,
)

def _find_source_node_recipe(
self,
linkable_spec_set: LinkableSpecSet,
predicate_pushdown_state: PredicatePushdownState,
measure_spec_properties: Optional[MeasureSpecProperties] = None,
) -> Optional[SourceNodeRecipe]:
def _find_source_node_recipe(self, parameter_set: FindSourceNodeRecipeParameterSet) -> Optional[SourceNodeRecipe]:
"""Find the most suitable source nodes to satisfy the requested specs, as well as how to join them."""
result = self._cache.get_find_source_node_recipe_result(parameter_set)
if result is not None:
return result.source_node_recipe
source_node_recipe = self._find_source_node_recipe_non_cached(parameter_set)
self._cache.set_find_source_node_recipe_result(parameter_set, FindSourceNodeRecipeResult(source_node_recipe))
if source_node_recipe is not None:
return SourceNodeRecipe(
source_node=source_node_recipe.source_node,
required_local_linkable_specs=source_node_recipe.required_local_linkable_specs,
join_linkable_instances_recipes=source_node_recipe.join_linkable_instances_recipes,
)
return None

def _find_source_node_recipe_non_cached(
self, parameter_set: FindSourceNodeRecipeParameterSet
) -> Optional[SourceNodeRecipe]:
linkable_spec_set = parameter_set.linkable_spec_set
predicate_pushdown_state = parameter_set.predicate_pushdown_state
measure_spec_properties = parameter_set.measure_spec_properties

candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = []
candidate_nodes_for_right_side_of_join: List[DataflowPlanNode] = []

Expand Down Expand Up @@ -1505,9 +1552,11 @@ def _build_aggregated_measure_from_measure_source_node(

find_recipe_start_time = time.time()
measure_recipe = self._find_source_node_recipe(
measure_spec_properties=measure_properties,
predicate_pushdown_state=measure_pushdown_state,
linkable_spec_set=required_linkable_specs,
FindSourceNodeRecipeParameterSet(
measure_spec_properties=measure_properties,
predicate_pushdown_state=measure_pushdown_state,
linkable_spec_set=required_linkable_specs,
)
)
logger.debug(
LazyFormat(
Expand Down
3 changes: 3 additions & 0 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.dataflow.builder.builder_cache import DataflowPlanBuilderCache
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.source_node import SourceNodeBuilder
Expand Down Expand Up @@ -395,12 +396,14 @@ def __init__(
)
node_output_resolver.cache_output_data_sets(source_node_set.all_nodes)

self._dataflow_plan_builder_cache = DataflowPlanBuilderCache()
self._dataflow_plan_builder = DataflowPlanBuilder(
source_node_set=source_node_set,
semantic_manifest_lookup=self._semantic_manifest_lookup,
column_association_resolver=self._column_association_resolver,
node_output_resolver=node_output_resolver,
source_node_builder=source_node_builder,
dataflow_plan_builder_cache=self._dataflow_plan_builder_cache,
)
self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter(
column_association_resolver=self._column_association_resolver,
Expand Down

0 comments on commit 56f2d0a

Please sign in to comment.