Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precompute / Cache Outputs for Nodes in SourceNodeSet #1030

Merged
merged 5 commits
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 4 additions & 20 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.mf_logging.runtime import log_runtime
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver
from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import FilterSpecResolutionLookUp
Expand Down Expand Up @@ -122,30 +121,15 @@ def __init__( # noqa: D
self,
source_node_set: SourceNodeSet,
semantic_manifest_lookup: SemanticManifestLookup,
node_output_resolver: Optional[DataflowPlanNodeOutputDataSetResolver] = None,
column_association_resolver: Optional[ColumnAssociationResolver] = None,
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
column_association_resolver: ColumnAssociationResolver,
) -> None:
self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup
self._metric_lookup = semantic_manifest_lookup.metric_lookup
self._metric_time_dimension_reference = DataSet.metric_time_dimension_reference()
self._source_node_set = source_node_set
self._column_association_resolver = (
DunderColumnAssociationResolver(semantic_manifest_lookup)
if not column_association_resolver
else column_association_resolver
)
self._node_data_set_resolver = (
DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=(
DunderColumnAssociationResolver(semantic_manifest_lookup)
if not column_association_resolver
else column_association_resolver
),
semantic_manifest_lookup=semantic_manifest_lookup,
)
if not node_output_resolver
else node_output_resolver
)
self._column_association_resolver = column_association_resolver
self._node_data_set_resolver = node_output_resolver

def build_plan(
self,
Expand Down
25 changes: 22 additions & 3 deletions metricflow/dataflow/builder/node_data_set.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import Dict
from typing import Dict, Optional, Sequence

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
)
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.mf_logging.runtime import log_block_runtime
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.specs.column_assoc import ColumnAssociationResolver
Expand Down Expand Up @@ -58,16 +59,34 @@ def __init__( # noqa: D
self,
column_association_resolver: ColumnAssociationResolver,
semantic_manifest_lookup: SemanticManifestLookup,
_node_to_output_data_set: Optional[Dict[DataflowPlanNode, SqlDataSet]] = None,
) -> None:
self._node_to_output_data_set: Dict[DataflowPlanNode, SqlDataSet] = {}
self._node_to_output_data_set: Dict[DataflowPlanNode, SqlDataSet] = _node_to_output_data_set or {}
super().__init__(
column_association_resolver=column_association_resolver,
semantic_manifest_lookup=semantic_manifest_lookup,
)

def get_output_data_set(self, node: DataflowPlanNode) -> SqlDataSet:
    """Return the output data set for `node`, memoizing the result.

    Cached since this will be called repeatedly for the same nodes during the
    computation of multiple metrics.
    """
    # TODO: The cache needs to be pruned, but has not yet been an issue.
    try:
        return self._node_to_output_data_set[node]
    except KeyError:
        data_set = node.accept(self)
        self._node_to_output_data_set[node] = data_set
        return data_set

def cache_output_data_sets(self, nodes: Sequence[DataflowPlanNode]) -> None:
    """Cache the output of the given nodes for consistent retrieval with `get_output_data_set`."""
    description = f"cache_output_data_sets for {len(nodes)} nodes"
    with log_block_runtime(description):
        for dataflow_node in nodes:
            # `get_output_data_set` populates the node -> data-set cache as a side effect.
            self.get_output_data_set(dataflow_node)

def copy(self) -> DataflowPlanNodeOutputDataSetResolver:
    """Return a copy of this resolver with the same nodes cached.

    The cache dict is shallow-copied, so nodes cached later on the copy do not
    appear in this instance's cache (and vice versa).
    """
    # NOTE(review): `column_association_resolver` (no leading underscore) is presumably
    # an accessor provided by the superclass — confirm it mirrors the resolver passed
    # to `__init__`.
    cached_data_sets = dict(self._node_to_output_data_set)
    return DataflowPlanNodeOutputDataSetResolver(
        column_association_resolver=self.column_association_resolver,
        semantic_manifest_lookup=self._semantic_manifest_lookup,
        _node_to_output_data_set=cached_data_sets,
    )
8 changes: 7 additions & 1 deletion metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ class SourceNodeSet:
source_nodes_for_metric_queries: Tuple[BaseOutput, ...]

# Semantic models are 1:1 mapped to a ReadSqlSourceNode. The tuple also contains the same `time_spine_node` as
# below.
# below. See usage in `DataflowPlanBuilder`.
source_nodes_for_group_by_item_queries: Tuple[BaseOutput, ...]

# Provides the time spine.
time_spine_node: MetricTimeDimensionTransformNode

@property
def all_nodes(self) -> Sequence[BaseOutput]:
    """Return every source node in this set.

    Includes the metric-query nodes, the group-by-item-query nodes, and the time-spine node.
    """
    combined = list(self.source_nodes_for_metric_queries)
    combined.extend(self.source_nodes_for_group_by_item_queries)
    combined.append(self.time_spine_node)
    return tuple(combined)


class SourceNodeBuilder:
"""Helps build a `SourceNodeSet` - refer to that class for more details."""
Expand Down
15 changes: 13 additions & 2 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dbt_semantic_interfaces.type_enums import DimensionType

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.source_node import SourceNodeBuilder
from metricflow.dataflow.dataflow_plan import DataflowPlan
from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import (
Expand Down Expand Up @@ -271,7 +272,7 @@ def get_dimension_values(
"""Retrieves a list of dimension values given a [metric_name, get_group_by_values].

Args:
metric_name: Names of metrics that contain the group_by.
metric_names: Names of metrics that contain the group_by.
get_group_by_values: Name of group_by to get values from.
time_constraint_start: Get data for the start of this time range.
time_constraint_end: Get data for the end of this time range.
Expand All @@ -294,8 +295,10 @@ def explain_get_dimension_values( # noqa: D
"""Returns the SQL query for get_dimension_values.

Args:
metric_name: Names of metrics that contain the group_by.
metric_names: Names of metrics that contain the group_by.
metrics: Similar to `metric_names`, but specified via parameter objects.
get_group_by_values: Name of group_by to get values from.
group_by: Similar to `get_group_by_values`, but specified via parameter objects.
time_constraint_start: Get data for the start of this time range.
time_constraint_end: Get data for the end of this time range.

Expand Down Expand Up @@ -351,9 +354,17 @@ def __init__(
)
source_node_set = source_node_builder.create_from_data_sets(self._source_data_sets)

node_output_resolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=self._column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
node_output_resolver.cache_output_data_sets(source_node_set.all_nodes)

self._dataflow_plan_builder = DataflowPlanBuilder(
source_node_set=source_node_set,
semantic_manifest_lookup=self._semantic_manifest_lookup,
column_association_resolver=self._column_association_resolver,
node_output_resolver=node_output_resolver,
)
self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter(
column_association_resolver=self._column_association_resolver,
Expand Down
20 changes: 18 additions & 2 deletions metricflow/mf_logging/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import functools
import logging
import time
from typing import Callable, TypeVar
from contextlib import contextmanager
from typing import Callable, Iterator, TypeVar

from typing_extensions import ParamSpec

Expand All @@ -14,7 +15,7 @@


def log_runtime(
runtime_warning_threshold: float = 3.0,
runtime_warning_threshold: float = 5.0,
) -> Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]]:
"""Logs how long a function took to run.

Expand Down Expand Up @@ -43,3 +44,18 @@ def _inner(*args: ParametersType.args, **kwargs: ParametersType.kwargs) -> Retur
return _inner

return decorator


@contextmanager
def log_block_runtime(code_block_name: str, runtime_warning_threshold: float = 5.0) -> Iterator[None]:
    """Logs the runtime of the enclosed code block.

    Args:
        code_block_name: Human-readable label included in the log lines.
        runtime_warning_threshold: Emit a warning if the block takes longer than this many seconds.
    """
    start_time = time.time()
    description = f"code_block_name={repr(code_block_name)}"
    logger.info(f"Starting {description}")

    try:
        yield
    finally:
        # Fix: the original logged nothing when the enclosed block raised, because the
        # code after `yield` never ran. try/finally ensures the runtime is always logged,
        # so slow *failing* blocks are visible too.
        runtime = time.time() - start_time
        logger.info(f"Finished {description} in {runtime:.1f}s")
        if runtime > runtime_warning_threshold:
            logger.warning(f"{description} is slow with a runtime of {runtime:.1f}s")
1 change: 1 addition & 0 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(
semantic_manifest_lookup: Self-explanatory.
"""
self._column_association_resolver = column_association_resolver
self._semantic_manifest_lookup = semantic_manifest_lookup
self._metric_lookup = semantic_manifest_lookup.metric_lookup
self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup
self._time_spine_source = semantic_manifest_lookup.time_spine_source
Expand Down
16 changes: 9 additions & 7 deletions metricflow/test/fixtures/manifest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class MetricFlowEngineTestFixture:
query_parser: MetricFlowQueryParser
metricflow_engine: MetricFlowEngine

_node_output_resolver: DataflowPlanNodeOutputDataSetResolver

@staticmethod
def from_parameters( # noqa: D
sql_client: SqlClient, semantic_manifest: PydanticSemanticManifest
Expand All @@ -122,7 +124,11 @@ def from_parameters( # noqa: D
source_node_set = MetricFlowEngineTestFixture._data_set_to_source_node_set(
column_association_resolver, semantic_manifest_lookup, data_set_mapping
)

node_output_resolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=column_association_resolver,
semantic_manifest_lookup=semantic_manifest_lookup,
)
node_output_resolver.cache_output_data_sets(source_node_set.all_nodes)
query_parser = MetricFlowQueryParser(semantic_manifest_lookup=semantic_manifest_lookup)
return MetricFlowEngineTestFixture(
semantic_manifest=semantic_manifest,
Expand All @@ -131,6 +137,7 @@ def from_parameters( # noqa: D
data_set_mapping=data_set_mapping,
read_node_mapping=read_node_mapping,
source_node_set=source_node_set,
_node_output_resolver=node_output_resolver,
dataflow_to_sql_converter=DataflowToSqlQueryPlanConverter(
column_association_resolver=column_association_resolver,
semantic_manifest_lookup=semantic_manifest_lookup,
Expand All @@ -151,15 +158,10 @@ def dataflow_plan_builder(self) -> DataflowPlanBuilder:

This should be recreated for each test since DataflowPlanBuilder contains a stateful cache.
"""
node_output_resolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=self.column_association_resolver,
semantic_manifest_lookup=self.semantic_manifest_lookup,
)

return DataflowPlanBuilder(
source_node_set=self.source_node_set,
semantic_manifest_lookup=self.semantic_manifest_lookup,
node_output_resolver=node_output_resolver,
node_output_resolver=self._node_output_resolver.copy(),
column_association_resolver=self.column_association_resolver,
)

Expand Down
Loading
Loading