Skip to content

Commit

Permalink
Cache building operations in DataflowPlanBuilder (#1448)
Browse files Browse the repository at this point in the history
This PR adds a few LRU caches to handle building operations within the
`DataflowPlanBuilder`. Since the same metric may be used multiple times
in a derived metric (or between queries), caching can yield significant
performance improvements. Please review commit by commit, as signature
and type changes were made to keep the changes cleaner.
  • Loading branch information
plypaul authored Oct 11, 2024
1 parent a68b992 commit 804bee5
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 103 deletions.
78 changes: 78 additions & 0 deletions metricflow/dataflow/builder/builder_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from metricflow_semantics.collection_helpers.lru_cache import LruCache
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory

from metricflow.dataflow.builder.measure_spec_properties import MeasureSpecProperties
from metricflow.dataflow.builder.source_node_recipe import SourceNodeRecipe
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.plan_conversion.node_processor import PredicatePushdownState


@dataclass(frozen=True)
class FindSourceNodeRecipeParameterSet:
    """Parameters for `DataflowPlanBuilder._find_source_node_recipe()`.

    Frozen so instances are hashable and can serve as LRU-cache keys in
    `DataflowPlanBuilderCache`; all field values must therefore be hashable.
    """

    # Linkable specs (dimensions / entities / time dimensions) that the recipe must satisfy.
    linkable_spec_set: LinkableSpecSet
    # State describing which predicates may be pushed down toward the source nodes.
    predicate_pushdown_state: PredicatePushdownState
    # Properties of the measures to aggregate; None when resolving distinct values
    # (no measures involved — see `_build_plan_for_distinct_values`).
    measure_spec_properties: Optional[MeasureSpecProperties]


@dataclass(frozen=True)
class FindSourceNodeRecipeResult:
    """Result for `DataflowPlanBuilder._find_source_node_recipe()`.

    Wraps the Optional recipe so that a cached "no recipe found" (None) can be
    distinguished from a cache miss (the cache's `get()` returning None).
    """

    # The recipe found for the parameter set, or None if no recipe could satisfy it.
    source_node_recipe: Optional[SourceNodeRecipe]


@dataclass(frozen=True)
class BuildAnyMetricOutputNodeParameterSet:
    """Parameters for `DataflowPlanBuilder._build_any_metric_output_node()`.

    Frozen so instances are hashable and can serve as LRU-cache keys in
    `DataflowPlanBuilderCache`; all field values must therefore be hashable.
    """

    # The metric to compute (any metric type).
    metric_spec: MetricSpec
    # Linkable specs requested by the query that the output node must include.
    queried_linkable_specs: LinkableSpecSet
    # Factory used to create where-filter specs for the metric's filters.
    filter_spec_factory: WhereSpecFactory
    # State describing which predicates may be pushed down toward the source nodes.
    predicate_pushdown_state: PredicatePushdownState
    # True when this node is being built to back a group-by-metric source node.
    for_group_by_source_node: bool


class DataflowPlanBuilderCache:
    """Cache for internal methods in `DataflowPlanBuilder`.

    Holds bounded LRU caches keyed by the frozen parameter-set dataclasses so that
    repeated build operations (e.g. the same metric appearing multiple times in a
    derived metric, or across queries) are not recomputed.
    """

    def __init__(
        self, find_source_node_recipe_cache_size: int = 1000, build_any_metric_output_node_cache_size: int = 1000
    ) -> None:
        """Initialize the caches with the given maximum sizes.

        Args:
            find_source_node_recipe_cache_size: Max entries for `_find_source_node_recipe()` results.
            build_any_metric_output_node_cache_size: Max entries for `_build_any_metric_output_node()` results.

        Raises:
            ValueError: If either cache size is not a positive integer.
        """
        # Validate before allocating anything: the original implementation asserted
        # *after* constructing the caches (so invalid sizes still reached LruCache),
        # and `assert` is stripped under `python -O`.
        if find_source_node_recipe_cache_size <= 0:
            raise ValueError(
                f"find_source_node_recipe_cache_size must be > 0, got {find_source_node_recipe_cache_size}"
            )
        if build_any_metric_output_node_cache_size <= 0:
            raise ValueError(
                f"build_any_metric_output_node_cache_size must be > 0, got {build_any_metric_output_node_cache_size}"
            )

        self._find_source_node_recipe_cache = LruCache[FindSourceNodeRecipeParameterSet, FindSourceNodeRecipeResult](
            find_source_node_recipe_cache_size
        )
        self._build_any_metric_output_node_cache = LruCache[BuildAnyMetricOutputNodeParameterSet, DataflowPlanNode](
            build_any_metric_output_node_cache_size
        )

    def get_find_source_node_recipe_result(
        self, parameter_set: FindSourceNodeRecipeParameterSet
    ) -> Optional[FindSourceNodeRecipeResult]:
        """Return the cached result for the parameter set, or None on a cache miss."""
        return self._find_source_node_recipe_cache.get(parameter_set)

    def set_find_source_node_recipe_result(
        self, parameter_set: FindSourceNodeRecipeParameterSet, source_node_recipe: FindSourceNodeRecipeResult
    ) -> None:
        """Store the result for the parameter set, evicting the LRU entry if full."""
        self._find_source_node_recipe_cache.set(parameter_set, source_node_recipe)

    def get_build_any_metric_output_node_result(
        self, parameter_set: BuildAnyMetricOutputNodeParameterSet
    ) -> Optional[DataflowPlanNode]:
        """Return the cached output node for the parameter set, or None on a cache miss."""
        return self._build_any_metric_output_node_cache.get(parameter_set)

    def set_build_any_metric_output_node_result(
        self, parameter_set: BuildAnyMetricOutputNodeParameterSet, dataflow_plan_node: DataflowPlanNode
    ) -> None:
        """Store the output node for the parameter set, evicting the LRU entry if full."""
        self._build_any_metric_output_node_cache.set(parameter_set, dataflow_plan_node)
163 changes: 94 additions & 69 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import time
from dataclasses import dataclass
from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand Down Expand Up @@ -58,13 +57,20 @@
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.dataflow.builder.builder_cache import (
BuildAnyMetricOutputNodeParameterSet,
DataflowPlanBuilderCache,
FindSourceNodeRecipeParameterSet,
FindSourceNodeRecipeResult,
)
from metricflow.dataflow.builder.measure_spec_properties import MeasureSpecProperties
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.node_evaluator import (
JoinLinkableInstancesRecipe,
LinkableInstanceSatisfiabilityEvaluation,
NodeEvaluatorForLinkableInstances,
)
from metricflow.dataflow.builder.source_node import SourceNodeBuilder, SourceNodeSet
from metricflow.dataflow.builder.source_node_recipe import SourceNodeRecipe
from metricflow.dataflow.dataflow_plan import (
DataflowPlan,
DataflowPlanNode,
Expand All @@ -77,7 +83,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand All @@ -101,30 +107,6 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class SourceNodeRecipe:
"""Get a recipe for how to build a dataflow plan node that outputs measures and linkable instances as needed."""

source_node: DataflowPlanNode
required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]

@property
def join_targets(self) -> List[JoinDescription]:
"""Joins to be made to source node."""
return [join_recipe.join_description for join_recipe in self.join_linkable_instances_recipes]


@dataclass(frozen=True)
class MeasureSpecProperties:
"""Input dataclass for grouping properties of a sequence of MeasureSpecs."""

measure_specs: Sequence[MeasureSpec]
semantic_model_name: str
agg_time_dimension: TimeDimensionReference
non_additive_dimension_spec: Optional[NonAdditiveDimensionSpec] = None


class DataflowPlanBuilder:
"""Builds a dataflow plan to satisfy a given query."""

Expand All @@ -135,6 +117,7 @@ def __init__( # noqa: D107
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
column_association_resolver: ColumnAssociationResolver,
source_node_builder: SourceNodeBuilder,
dataflow_plan_builder_cache: Optional[DataflowPlanBuilderCache] = None,
) -> None:
self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup
self._metric_lookup = semantic_manifest_lookup.metric_lookup
Expand All @@ -144,6 +127,7 @@ def __init__( # noqa: D107
self._node_data_set_resolver = node_output_resolver
self._source_node_builder = source_node_builder
self._time_period_adjuster = DateutilTimePeriodAdjuster()
self._cache = dataflow_plan_builder_cache or DataflowPlanBuilderCache()

def build_plan(
self,
Expand Down Expand Up @@ -278,15 +262,19 @@ def _build_aggregated_conversion_node(
filter_specs=base_measure_spec.filter_specs,
)
base_measure_recipe = self._find_source_node_recipe(
measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
predicate_pushdown_state=time_range_only_pushdown_state,
linkable_spec_set=base_required_linkable_specs,
FindSourceNodeRecipeParameterSet(
measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
predicate_pushdown_state=time_range_only_pushdown_state,
linkable_spec_set=base_required_linkable_specs,
)
)
logger.debug(LazyFormat(lambda: f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}"))
conversion_measure_recipe = self._find_source_node_recipe(
measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]),
predicate_pushdown_state=disabled_pushdown_state,
linkable_spec_set=LinkableSpecSet(),
FindSourceNodeRecipeParameterSet(
measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]),
predicate_pushdown_state=disabled_pushdown_state,
linkable_spec_set=LinkableSpecSet(),
)
)
logger.debug(
LazyFormat(lambda: f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}")
Expand Down Expand Up @@ -620,18 +608,21 @@ def _build_derived_metric_output_node(

parent_nodes.append(
self._build_any_metric_output_node(
metric_spec=MetricSpec(
element_name=metric_input_spec.element_name,
filter_specs=tuple(filter_specs),
alias=metric_input_spec.alias,
offset_window=metric_input_spec.offset_window,
offset_to_grain=metric_input_spec.offset_to_grain,
),
queried_linkable_specs=(
queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs
),
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=metric_pushdown_state,
BuildAnyMetricOutputNodeParameterSet(
metric_spec=MetricSpec(
element_name=metric_input_spec.element_name,
filter_specs=tuple(filter_specs),
alias=metric_input_spec.alias,
offset_window=metric_input_spec.offset_window,
offset_to_grain=metric_input_spec.offset_to_grain,
),
queried_linkable_specs=(
queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs
),
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=metric_pushdown_state,
for_group_by_source_node=False,
)
)
)

Expand Down Expand Up @@ -676,15 +667,26 @@ def _build_derived_metric_output_node(
)
return output_node

def _build_any_metric_output_node(
self,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
filter_spec_factory: WhereSpecFactory,
predicate_pushdown_state: PredicatePushdownState,
for_group_by_source_node: bool = False,
def _build_any_metric_output_node(self, parameter_set: BuildAnyMetricOutputNodeParameterSet) -> DataflowPlanNode:
    """Builds a node to compute a metric of any type.

    Thin memoizing wrapper around `_build_any_metric_output_node_non_cached()`:
    results are keyed on the full parameter set in `DataflowPlanBuilderCache`, so
    the same metric appearing multiple times (e.g. in a derived metric) is built once.
    """
    # Cache hit: the identical parameter set was built before — reuse the node.
    result = self._cache.get_build_any_metric_output_node_result(parameter_set)
    if result is not None:
        return result

    # Cache miss: build and remember the result for subsequent identical requests.
    result = self._build_any_metric_output_node_non_cached(parameter_set)
    self._cache.set_build_any_metric_output_node_result(parameter_set, result)
    return result

def _build_any_metric_output_node_non_cached(
self, parameter_set: BuildAnyMetricOutputNodeParameterSet
) -> DataflowPlanNode:
"""Builds a node to compute a metric of any type."""
metric_spec = parameter_set.metric_spec
queried_linkable_specs = parameter_set.queried_linkable_specs
filter_spec_factory = parameter_set.filter_spec_factory
predicate_pushdown_state = parameter_set.predicate_pushdown_state
for_group_by_source_node = parameter_set.for_group_by_source_node

metric = self._metric_lookup.get_metric(metric_spec.reference)

if metric.type is MetricType.SIMPLE:
Expand Down Expand Up @@ -750,11 +752,13 @@ def _build_metrics_output_node(

output_nodes.append(
self._build_any_metric_output_node(
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=predicate_pushdown_state,
for_group_by_source_node=for_group_by_source_node,
BuildAnyMetricOutputNodeParameterSet(
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
filter_spec_factory=filter_spec_factory,
predicate_pushdown_state=predicate_pushdown_state,
for_group_by_source_node=for_group_by_source_node,
)
)
)

Expand Down Expand Up @@ -798,10 +802,14 @@ def _build_plan_for_distinct_values(
queried_linkable_specs=query_spec.linkable_specs, filter_specs=query_level_filter_specs
)
predicate_pushdown_state = PredicatePushdownState(
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=query_level_filter_specs
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=tuple(query_level_filter_specs)
)
dataflow_recipe = self._find_source_node_recipe(
linkable_spec_set=required_linkable_specs, predicate_pushdown_state=predicate_pushdown_state
FindSourceNodeRecipeParameterSet(
linkable_spec_set=required_linkable_specs,
predicate_pushdown_state=predicate_pushdown_state,
measure_spec_properties=None,
)
)
if not dataflow_recipe:
raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}")
Expand Down Expand Up @@ -968,19 +976,34 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
if measure_agg_time_dimension != agg_time_dimension:
raise ValueError(f"measure_specs {measure_specs} do not have the same agg_time_dimension.")
return MeasureSpecProperties(
measure_specs=measure_specs,
measure_specs=tuple(measure_specs),
semantic_model_name=semantic_model_name,
agg_time_dimension=agg_time_dimension,
non_additive_dimension_spec=non_additive_dimension_spec,
)

def _find_source_node_recipe(
self,
linkable_spec_set: LinkableSpecSet,
predicate_pushdown_state: PredicatePushdownState,
measure_spec_properties: Optional[MeasureSpecProperties] = None,
) -> Optional[SourceNodeRecipe]:
def _find_source_node_recipe(self, parameter_set: FindSourceNodeRecipeParameterSet) -> Optional[SourceNodeRecipe]:
    """Find the most suitable source nodes to satisfy the requested specs, as well as how to join them.

    Memoizing wrapper around `_find_source_node_recipe_non_cached()`. The result is
    stored wrapped in `FindSourceNodeRecipeResult` so that a cached None ("no recipe
    exists") is distinguishable from a cache miss (`get` returning None).
    """
    result = self._cache.get_find_source_node_recipe_result(parameter_set)
    if result is not None:
        return result.source_node_recipe
    source_node_recipe = self._find_source_node_recipe_non_cached(parameter_set)
    self._cache.set_find_source_node_recipe_result(parameter_set, FindSourceNodeRecipeResult(source_node_recipe))
    if source_node_recipe is not None:
        # NOTE(review): the miss path returns a fresh SourceNodeRecipe copy while the
        # hit path above returns the cached instance directly — presumably a defensive
        # copy, but the asymmetry looks unintentional; confirm whether copying is
        # needed at all given SourceNodeRecipe is a frozen dataclass.
        return SourceNodeRecipe(
            source_node=source_node_recipe.source_node,
            required_local_linkable_specs=source_node_recipe.required_local_linkable_specs,
            join_linkable_instances_recipes=source_node_recipe.join_linkable_instances_recipes,
        )
    return None

def _find_source_node_recipe_non_cached(
self, parameter_set: FindSourceNodeRecipeParameterSet
) -> Optional[SourceNodeRecipe]:
linkable_spec_set = parameter_set.linkable_spec_set
predicate_pushdown_state = parameter_set.predicate_pushdown_state
measure_spec_properties = parameter_set.measure_spec_properties

candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = []
candidate_nodes_for_right_side_of_join: List[DataflowPlanNode] = []

Expand Down Expand Up @@ -1529,9 +1552,11 @@ def _build_aggregated_measure_from_measure_source_node(

find_recipe_start_time = time.time()
measure_recipe = self._find_source_node_recipe(
measure_spec_properties=measure_properties,
predicate_pushdown_state=measure_pushdown_state,
linkable_spec_set=required_linkable_specs,
FindSourceNodeRecipeParameterSet(
measure_spec_properties=measure_properties,
predicate_pushdown_state=measure_pushdown_state,
linkable_spec_set=required_linkable_specs,
)
)
logger.debug(
LazyFormat(
Expand Down
18 changes: 18 additions & 0 deletions metricflow/dataflow/builder/measure_spec_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence

from dbt_semantic_interfaces.references import TimeDimensionReference
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.non_additive_dimension_spec import NonAdditiveDimensionSpec


@dataclass(frozen=True)
class MeasureSpecProperties:
    """Input dataclass for grouping properties of a sequence of MeasureSpecs.

    NOTE(review): instances are embedded in `FindSourceNodeRecipeParameterSet`, which is
    used as an LRU-cache key, so `measure_specs` must hold a hashable value — callers
    pass `tuple(measure_specs)`; confirm whether the annotation should be tightened to
    `Tuple[MeasureSpec, ...]`.
    """

    # The measures being grouped; all must share the same semantic model and agg time dimension.
    measure_specs: Sequence[MeasureSpec]
    # Name of the semantic model that defines all of the measures.
    semantic_model_name: str
    # The aggregation time dimension shared by all of the measures.
    agg_time_dimension: TimeDimensionReference
    # Non-additive dimension configuration, if the measures have one.
    non_additive_dimension_spec: Optional[NonAdditiveDimensionSpec] = None
24 changes: 24 additions & 0 deletions metricflow/dataflow/builder/source_node_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec

from metricflow.dataflow.builder.node_evaluator import JoinLinkableInstancesRecipe
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription


@dataclass(frozen=True)
class SourceNodeRecipe:
    """Recipe for building a dataflow plan node that outputs the needed measures and linkable instances.

    Describes which source node to start from, which linkable specs it already
    provides locally, and which joins must be performed to obtain the rest.
    """

    # The node to use as the left side of any joins.
    source_node: DataflowPlanNode
    # Linkable specs available directly from `source_node` without joining.
    required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
    # One recipe per join needed to bring in the remaining linkable instances.
    join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]

    @property
    def join_targets(self) -> List[JoinDescription]:
        """Joins to be made to source node."""
        targets: List[JoinDescription] = []
        for join_recipe in self.join_linkable_instances_recipes:
            targets.append(join_recipe.join_description)
        return targets
Loading

0 comments on commit 804bee5

Please sign in to comment.