Add a method to figure out common nodes in a dataflow plan (#1520)
This PR adds a method to find nodes in a dataflow plan that appear more than
once, i.e. nodes that are the parent of multiple other nodes. These common
nodes indicate operations where a computation is reused, e.g. a metric that is
used in the computation of multiple derived metrics in a query. These nodes
will later be used to generate CTEs.
plypaul authored Nov 13, 2024
1 parent 12c94c2 commit 668c5f8
Showing 9 changed files with 457 additions and 31 deletions.
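A rough usage sketch of the new analyzer, based on the `find_common_branches` signature added in `dataflow_plan_analyzer.py` below; how `dataflow_plan` gets built is assumed here and is not part of this commit:

from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer

# `dataflow_plan` is assumed to be a DataflowPlan built elsewhere (e.g. by the dataflow plan builder).
common_branch_roots = DataflowPlanAnalyzer.find_common_branches(dataflow_plan=dataflow_plan)
for node in common_branch_roots:
    print(node.node_id)  # each returned node is the root of a branch that appears more than once in the plan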
2 changes: 1 addition & 1 deletion metricflow-semantics/metricflow_semantics/dag/mf_dag.py
@@ -32,7 +32,7 @@ class DisplayedProperty: # type: ignore
value: Any # type: ignore


@dataclass(frozen=True)
@dataclass(frozen=True, order=True)
class NodeId:
"""Unique identifier for nodes in DAGs."""

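For context on the `order=True` change: it makes the generated dataclass comparable field by field, which is what allows node IDs (and, later in this commit, dataflow plan nodes ordered by them) to be sorted deterministically. A minimal standalone sketch, using a simplified stand-in rather than the real `NodeId` class:

from dataclasses import dataclass

@dataclass(frozen=True, order=True)
class SimpleNodeId:
    """Simplified stand-in for NodeId: order=True generates __lt__/__le__/__gt__/__ge__ from the fields."""

    id_str: str

assert SimpleNodeId("a") < SimpleNodeId("b")
assert sorted([SimpleNodeId("b"), SimpleNodeId("a")]) == [SimpleNodeId("a"), SimpleNodeId("b")]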
@@ -437,15 +437,19 @@ def mf_pformat_dict( # type: ignore
include_none_object_fields: bool = False,
include_empty_object_fields: bool = False,
preserve_raw_strings: bool = False,
pad_items_with_newlines: bool = False,
) -> str:
"""Prints many objects in an indented form.
If preserve_raw_strings is set, and a value of the obj_dict is of type str, then use the value itself, not the
If `preserve_raw_strings` is set, and a value of the obj_dict is of type str, then use the value itself, not the
representation of the string. e.g. if value="foo", then "foo" instead of "'foo'". Useful for values that contain
newlines.
If `pad_items_with_newlines` is set, each key / value section is padded with newlines.
"""
lines: List[str] = [description] if description is not None else []
description_lines: List[str] = [description] if description is not None else []
obj_dict = obj_dict or {}
item_sections = []
for key, value in obj_dict.items():
if preserve_raw_strings and isinstance(value, str):
value_str = value
@@ -460,21 +464,25 @@
)

lines_in_value_str = len(value_str.split("\n"))
item_block_lines: Tuple[str, ...]
item_section_lines: Tuple[str, ...]
if lines_in_value_str > 1:
item_block_lines = (
item_section_lines = (
f"{key}:",
indent(
value_str,
indent_prefix=indent_prefix,
),
)
else:
item_block_lines = (f"{key}: {value_str}",)
item_block = "\n".join(item_block_lines)
item_section_lines = (f"{key}: {value_str}",)
item_section = "\n".join(item_section_lines)

if description is None:
lines.append(item_block)
item_sections.append(item_section)
else:
lines.append(indent(item_block))
return "\n".join(lines)
item_sections.append(indent(item_section))

if pad_items_with_newlines:
return "\n\n".join(description_lines + item_sections)
else:
return "\n".join(description_lines + item_sections)
2 changes: 1 addition & 1 deletion metricflow-semantics/metricflow_semantics/visitor.py
@@ -3,7 +3,7 @@
from abc import ABC
from typing import TypeVar

VisitorOutputT = TypeVar("VisitorOutputT")
VisitorOutputT = TypeVar("VisitorOutputT", covariant=True)


class Visitable(ABC):
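Background on the `covariant=True` change: a covariant type variable lets a visitor declared with a more specific output type be accepted where a visitor with a broader output type is expected, which can matter once visitors return container types such as `FrozenSet[DataflowPlanNode]`. A minimal standalone sketch of the typing behavior, with illustrative classes rather than the MetricFlow visitor hierarchy:

from typing import Generic, TypeVar

_OutputT = TypeVar("_OutputT", covariant=True)


class Producer(Generic[_OutputT]):
    """Illustrative generic class whose type parameter appears only in output (return) positions."""

    def produce(self) -> _OutputT:
        raise NotImplementedError


def consume(producer: Producer[object]) -> None:
    """Accepts any producer of objects."""


# Because _OutputT is covariant, type checkers accept Producer[str] where Producer[object] is expected.
consume(Producer[str]())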
@@ -218,3 +218,19 @@ def test_pformat_dict_with_empty_message() -> None:
)
== result
)


def test_pformat_dict_with_pad_sections_with_newline() -> None:
"""Test `mf_pformat_dict` with new lines between sections."""
result = mf_pformat_dict(obj_dict={"object_0": (1, 2, 3), "object_1": {4: 5}}, pad_items_with_newlines=True)

assert (
mf_dedent(
"""
object_0: (1, 2, 3)

object_1: {4: 5}
"""
)
== result
)
11 changes: 11 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import functools
import logging
import typing
from abc import ABC, abstractmethod
@@ -23,7 +24,11 @@

NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode")

# Make it so that we only have to suppress errors here instead of both at the method and the class.
ComparisonAnyType = typing.Any # type: ignore[misc]


@functools.total_ordering
@dataclass(frozen=True, eq=False)
class DataflowPlanNode(DagNode["DataflowPlanNode"], Visitable, ABC):
"""A node in the graph representation of the dataflow.
@@ -81,6 +86,12 @@ def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]:
"""Indicates that the node has been aggregated to these specs, guaranteeing uniqueness in all combinations."""
return set()

def __lt__(self, other: ComparisonAnyType) -> bool: # noqa: D105
if not isinstance(other, DataflowPlanNode):
raise NotImplementedError

return self.node_id < other.node_id


class DataflowPlan(MetricFlowDag[DataflowPlanNode]):
"""Describes the flow of metric data as it goes from source nodes to sink nodes in the graph."""
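`functools.total_ordering` derives the remaining rich comparisons (`__le__`, `__gt__`, `__ge__`) from `__lt__` and `__eq__`, so with the change above dataflow plan nodes can be sorted by `node_id`, which is what gives `find_common_branches` below its reproducible ordering. A minimal standalone sketch of the pattern, using a simplified stand-in class rather than `DataflowPlanNode`:

import functools
from dataclasses import dataclass


@functools.total_ordering
@dataclass(frozen=True)
class _ExampleNode:
    """Simplified stand-in for DataflowPlanNode: ordered by its ID string."""

    node_id: str

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, _ExampleNode):
            raise NotImplementedError
        return self.node_id < other.node_id


assert sorted([_ExampleNode("n2"), _ExampleNode("n1")]) == [_ExampleNode("n1"), _ExampleNode("n2")]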
79 changes: 79 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
@@ -0,0 +1,79 @@
from __future__ import annotations

from collections import defaultdict
from typing import Dict, FrozenSet, Mapping, Sequence, Set

from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler


class DataflowPlanAnalyzer:
"""Class to determine more complex properties of the dataflow plan.
These could also be made as member methods of the dataflow plan, but this requires resolving some circular
dependency issues to break out the functionality into separate files.
"""

@staticmethod
def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]:
"""Starting from the sink node, find the common branches that exist in the associated DAG.
Returns a sorted sequence for reproducibility.
"""
counting_visitor = _CountDataflowNodeVisitor()
dataflow_plan.sink_node.accept(counting_visitor)

node_to_common_count = counting_visitor.get_node_counts()

common_nodes = []
for node, count in node_to_common_count.items():
if count > 1:
common_nodes.append(node)

common_branches_visitor = _FindLargestCommonBranchesVisitor(frozenset(common_nodes))

return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor)))


class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
"""Helper visitor to build a dict from a node in the plan to the number of times it appears in the plan."""

def __init__(self) -> None:
self._node_to_count: Dict[DataflowPlanNode, int] = defaultdict(int)

def get_node_counts(self) -> Mapping[DataflowPlanNode, int]:
return self._node_to_count

@override
def _default_handler(self, node: DataflowPlanNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)
self._node_to_count[node] += 1


class _FindLargestCommonBranchesVisitor(DataflowPlanNodeVisitorWithDefaultHandler[FrozenSet[DataflowPlanNode]]):
"""Given the nodes that are known to appear in the DAG multiple times, find the common branches.
To get the largest common branches (e.g. for `A -> B -> C -> D` and `B -> C -> D`, both `B -> C -> D`
and `C -> D` can be considered common branches, and we want the largest one), this uses preorder traversal and
returns the first common node that is seen.
"""

def __init__(self, common_nodes: FrozenSet[DataflowPlanNode]) -> None:
self._common_nodes = common_nodes

@override
def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]:
# Traversal starts from the leaf node and then goes to the parent branches. By doing this check first, we don't
# return smaller common branches that are a part of a larger common branch.
if node in self._common_nodes:
return frozenset({node})

common_branch_leaf_nodes: Set[DataflowPlanNode] = set()

for parent_node in node.parent_nodes:
common_branch_leaf_nodes.update(parent_node.accept(self))

return frozenset(common_branch_leaf_nodes)
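To make the two passes concrete, here is a small standalone sketch of the same idea on a toy DAG expressed as a child-to-parents mapping; it is illustrative only and does not use the MetricFlow classes above:

from collections import defaultdict
from typing import Dict, FrozenSet, Mapping, Set

# Toy DAG: each node maps to its parent nodes (the nodes it reads from).
# "metric_a" feeds both derived metrics, so the branch rooted at it is a common branch.
PARENTS: Mapping[str, FrozenSet[str]] = {
    "sink": frozenset({"derived_1", "derived_2"}),
    "derived_1": frozenset({"metric_a"}),
    "derived_2": frozenset({"metric_a"}),
    "metric_a": frozenset({"source"}),
    "source": frozenset(),
}


def count_nodes(node: str, counts: Dict[str, int]) -> None:
    """First pass: count how many times each node is reached from the sink."""
    for parent in PARENTS[node]:
        count_nodes(parent, counts)
    counts[node] += 1


def largest_common_branches(node: str, common_nodes: FrozenSet[str]) -> FrozenSet[str]:
    """Second pass: preorder traversal that stops at the first common node on each path."""
    if node in common_nodes:
        return frozenset({node})
    result: Set[str] = set()
    for parent in PARENTS[node]:
        result |= largest_common_branches(parent, common_nodes)
    return frozenset(result)


counts: Dict[str, int] = defaultdict(int)
count_nodes("sink", counts)
common = frozenset(node for node, count in counts.items() if count > 1)
# Both "metric_a" and "source" appear twice, but only "metric_a" is returned because the
# preorder traversal stops at the first (largest) common node on each path from the sink.
assert largest_common_branches("sink", common) == frozenset({"metric_a"})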