diff --git a/metricflow/dataflow/builder/builder_cache.py b/metricflow/dataflow/builder/builder_cache.py new file mode 100644 index 0000000000..79444561ad --- /dev/null +++ b/metricflow/dataflow/builder/builder_cache.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from metricflow_semantics.collection_helpers.lru_cache import LruCache +from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet +from metricflow_semantics.specs.metric_spec import MetricSpec +from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory + +from metricflow.dataflow.builder.measure_spec_properties import MeasureSpecProperties +from metricflow.dataflow.builder.source_node_recipe import SourceNodeRecipe +from metricflow.dataflow.dataflow_plan import DataflowPlanNode +from metricflow.plan_conversion.node_processor import PredicatePushdownState + + +@dataclass(frozen=True) +class FindSourceNodeRecipeParameterSet: + """Parameters for `DataflowPlanBuilder._find_source_node_recipe()`.""" + + linkable_spec_set: LinkableSpecSet + predicate_pushdown_state: PredicatePushdownState + measure_spec_properties: Optional[MeasureSpecProperties] + + +@dataclass(frozen=True) +class FindSourceNodeRecipeResult: + """Result for `DataflowPlanBuilder._find_source_node_recipe()`.""" + + source_node_recipe: Optional[SourceNodeRecipe] + + +@dataclass(frozen=True) +class BuildAnyMetricOutputNodeParameterSet: + """Parameters for `DataflowPlanBuilder._build_any_metric_output_node()`.""" + + metric_spec: MetricSpec + queried_linkable_specs: LinkableSpecSet + filter_spec_factory: WhereSpecFactory + predicate_pushdown_state: PredicatePushdownState + for_group_by_source_node: bool + + +class DataflowPlanBuilderCache: + """Cache for internal methods in `DataflowPlanBuilder`.""" + + def __init__( # noqa: D107 + self, find_source_node_recipe_cache_size: int = 1000, build_any_metric_output_node_cache_size: int = 1000 + ) -> None: + self._find_source_node_recipe_cache = LruCache[FindSourceNodeRecipeParameterSet, FindSourceNodeRecipeResult]( + find_source_node_recipe_cache_size + ) + self._build_any_metric_output_node_cache = LruCache[BuildAnyMetricOutputNodeParameterSet, DataflowPlanNode]( + build_any_metric_output_node_cache_size + ) + + assert find_source_node_recipe_cache_size > 0 + assert build_any_metric_output_node_cache_size > 0 + + def get_find_source_node_recipe_result( # noqa: D102 + self, parameter_set: FindSourceNodeRecipeParameterSet + ) -> Optional[FindSourceNodeRecipeResult]: + return self._find_source_node_recipe_cache.get(parameter_set) + + def set_find_source_node_recipe_result( # noqa: D102 + self, parameter_set: FindSourceNodeRecipeParameterSet, source_node_recipe: FindSourceNodeRecipeResult + ) -> None: + self._find_source_node_recipe_cache.set(parameter_set, source_node_recipe) + + def get_build_any_metric_output_node_result( # noqa: D102 + self, parameter_set: BuildAnyMetricOutputNodeParameterSet + ) -> Optional[DataflowPlanNode]: + return self._build_any_metric_output_node_cache.get(parameter_set) + + def set_build_any_metric_output_node_result( # noqa: D102 + self, parameter_set: BuildAnyMetricOutputNodeParameterSet, dataflow_plan_node: DataflowPlanNode + ) -> None: + self._build_any_metric_output_node_cache.set(parameter_set, dataflow_plan_node) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 8a2efece15..40b938eb99 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -57,6 +57,12 @@ from metricflow_semantics.time.granularity import ExpandedTimeGranularity from metricflow_semantics.time.time_spine_source import TimeSpineSource +from metricflow.dataflow.builder.builder_cache import ( + BuildAnyMetricOutputNodeParameterSet, + DataflowPlanBuilderCache, + FindSourceNodeRecipeParameterSet, + FindSourceNodeRecipeResult, +) from metricflow.dataflow.builder.measure_spec_properties import MeasureSpecProperties from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.node_evaluator import ( @@ -111,6 +117,7 @@ def __init__( # noqa: D107 node_output_resolver: DataflowPlanNodeOutputDataSetResolver, column_association_resolver: ColumnAssociationResolver, source_node_builder: SourceNodeBuilder, + dataflow_plan_builder_cache: Optional[DataflowPlanBuilderCache] = None, ) -> None: self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup self._metric_lookup = semantic_manifest_lookup.metric_lookup @@ -120,6 +127,7 @@ def __init__( # noqa: D107 self._node_data_set_resolver = node_output_resolver self._source_node_builder = source_node_builder self._time_period_adjuster = DateutilTimePeriodAdjuster() + self._cache = dataflow_plan_builder_cache or DataflowPlanBuilderCache() def build_plan( self, @@ -254,15 +262,19 @@ def _build_aggregated_conversion_node( filter_specs=base_measure_spec.filter_specs, ) base_measure_recipe = self._find_source_node_recipe( - measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]), - predicate_pushdown_state=time_range_only_pushdown_state, - linkable_spec_set=base_required_linkable_specs, + FindSourceNodeRecipeParameterSet( + measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]), + predicate_pushdown_state=time_range_only_pushdown_state, + linkable_spec_set=base_required_linkable_specs, + ) ) logger.debug(LazyFormat(lambda: f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}")) conversion_measure_recipe = self._find_source_node_recipe( - measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]), - predicate_pushdown_state=disabled_pushdown_state, - linkable_spec_set=LinkableSpecSet(), + FindSourceNodeRecipeParameterSet( + measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]), + predicate_pushdown_state=disabled_pushdown_state, + linkable_spec_set=LinkableSpecSet(), + ) ) logger.debug( LazyFormat(lambda: f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}") @@ -596,18 +608,21 @@ def _build_derived_metric_output_node( parent_nodes.append( self._build_any_metric_output_node( - metric_spec=MetricSpec( - element_name=metric_input_spec.element_name, - filter_specs=tuple(filter_specs), - alias=metric_input_spec.alias, - offset_window=metric_input_spec.offset_window, - offset_to_grain=metric_input_spec.offset_to_grain, - ), - queried_linkable_specs=( - queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs - ), - filter_spec_factory=filter_spec_factory, - predicate_pushdown_state=metric_pushdown_state, + BuildAnyMetricOutputNodeParameterSet( + metric_spec=MetricSpec( + element_name=metric_input_spec.element_name, + filter_specs=tuple(filter_specs), + alias=metric_input_spec.alias, + offset_window=metric_input_spec.offset_window, + offset_to_grain=metric_input_spec.offset_to_grain, + ), + queried_linkable_specs=( + queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs + ), + filter_spec_factory=filter_spec_factory, + predicate_pushdown_state=metric_pushdown_state, + for_group_by_source_node=False, + ) ) ) @@ -652,15 +667,26 @@ def _build_derived_metric_output_node( ) return output_node - def _build_any_metric_output_node( - self, - metric_spec: MetricSpec, - queried_linkable_specs: LinkableSpecSet, - filter_spec_factory: WhereSpecFactory, - predicate_pushdown_state: PredicatePushdownState, - for_group_by_source_node: bool = False, + def _build_any_metric_output_node(self, parameter_set: BuildAnyMetricOutputNodeParameterSet) -> DataflowPlanNode: + """Builds a node to compute a metric of any type.""" + result = self._cache.get_build_any_metric_output_node_result(parameter_set) + if result is not None: + return result + + result = self._build_any_metric_output_node_non_cached(parameter_set) + self._cache.set_build_any_metric_output_node_result(parameter_set, result) + return result + + def _build_any_metric_output_node_non_cached( + self, parameter_set: BuildAnyMetricOutputNodeParameterSet ) -> DataflowPlanNode: """Builds a node to compute a metric of any type.""" + metric_spec = parameter_set.metric_spec + queried_linkable_specs = parameter_set.queried_linkable_specs + filter_spec_factory = parameter_set.filter_spec_factory + predicate_pushdown_state = parameter_set.predicate_pushdown_state + for_group_by_source_node = parameter_set.for_group_by_source_node + metric = self._metric_lookup.get_metric(metric_spec.reference) if metric.type is MetricType.SIMPLE: @@ -726,11 +752,13 @@ def _build_metrics_output_node( output_nodes.append( self._build_any_metric_output_node( - metric_spec=metric_spec, - queried_linkable_specs=queried_linkable_specs, - filter_spec_factory=filter_spec_factory, - predicate_pushdown_state=predicate_pushdown_state, - for_group_by_source_node=for_group_by_source_node, + BuildAnyMetricOutputNodeParameterSet( + metric_spec=metric_spec, + queried_linkable_specs=queried_linkable_specs, + filter_spec_factory=filter_spec_factory, + predicate_pushdown_state=predicate_pushdown_state, + for_group_by_source_node=for_group_by_source_node, + ) ) ) @@ -777,7 +805,11 @@ def _build_plan_for_distinct_values( time_range_constraint=query_spec.time_range_constraint, where_filter_specs=tuple(query_level_filter_specs) ) dataflow_recipe = self._find_source_node_recipe( - linkable_spec_set=required_linkable_specs, predicate_pushdown_state=predicate_pushdown_state + FindSourceNodeRecipeParameterSet( + linkable_spec_set=required_linkable_specs, + predicate_pushdown_state=predicate_pushdown_state, + measure_spec_properties=None, + ) ) if not dataflow_recipe: raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}") @@ -944,19 +976,34 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - if measure_agg_time_dimension != agg_time_dimension: raise ValueError(f"measure_specs {measure_specs} do not have the same agg_time_dimension.") return MeasureSpecProperties( - measure_specs=measure_specs, + measure_specs=tuple(measure_specs), semantic_model_name=semantic_model_name, agg_time_dimension=agg_time_dimension, non_additive_dimension_spec=non_additive_dimension_spec, ) - def _find_source_node_recipe( - self, - linkable_spec_set: LinkableSpecSet, - predicate_pushdown_state: PredicatePushdownState, - measure_spec_properties: Optional[MeasureSpecProperties] = None, - ) -> Optional[SourceNodeRecipe]: + def _find_source_node_recipe(self, parameter_set: FindSourceNodeRecipeParameterSet) -> Optional[SourceNodeRecipe]: """Find the most suitable source nodes to satisfy the requested specs, as well as how to join them.""" + result = self._cache.get_find_source_node_recipe_result(parameter_set) + if result is not None: + return result.source_node_recipe + source_node_recipe = self._find_source_node_recipe_non_cached(parameter_set) + self._cache.set_find_source_node_recipe_result(parameter_set, FindSourceNodeRecipeResult(source_node_recipe)) + if source_node_recipe is not None: + return SourceNodeRecipe( + source_node=source_node_recipe.source_node, + required_local_linkable_specs=source_node_recipe.required_local_linkable_specs, + join_linkable_instances_recipes=source_node_recipe.join_linkable_instances_recipes, + ) + return None + + def _find_source_node_recipe_non_cached( + self, parameter_set: FindSourceNodeRecipeParameterSet + ) -> Optional[SourceNodeRecipe]: + linkable_spec_set = parameter_set.linkable_spec_set + predicate_pushdown_state = parameter_set.predicate_pushdown_state + measure_spec_properties = parameter_set.measure_spec_properties + candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = [] candidate_nodes_for_right_side_of_join: List[DataflowPlanNode] = [] @@ -1505,9 +1552,11 @@ def _build_aggregated_measure_from_measure_source_node( find_recipe_start_time = time.time() measure_recipe = self._find_source_node_recipe( - measure_spec_properties=measure_properties, - predicate_pushdown_state=measure_pushdown_state, - linkable_spec_set=required_linkable_specs, + FindSourceNodeRecipeParameterSet( + measure_spec_properties=measure_properties, + predicate_pushdown_state=measure_pushdown_state, + linkable_spec_set=required_linkable_specs, + ) ) logger.debug( LazyFormat( diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index 31104bad14..fa572f7215 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -36,6 +36,7 @@ from metricflow_semantics.time.time_spine_source import TimeSpineSource from metricflow.data_table.mf_table import MetricFlowDataTable +from metricflow.dataflow.builder.builder_cache import DataflowPlanBuilderCache from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.source_node import SourceNodeBuilder @@ -395,12 +396,14 @@ def __init__( ) node_output_resolver.cache_output_data_sets(source_node_set.all_nodes) + self._dataflow_plan_builder_cache = DataflowPlanBuilderCache() self._dataflow_plan_builder = DataflowPlanBuilder( source_node_set=source_node_set, semantic_manifest_lookup=self._semantic_manifest_lookup, column_association_resolver=self._column_association_resolver, node_output_resolver=node_output_resolver, source_node_builder=source_node_builder, + dataflow_plan_builder_cache=self._dataflow_plan_builder_cache, ) self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter( column_association_resolver=self._column_association_resolver,