Skip to content

Commit

Permalink
Convert DAG node classes to dataclasses (#1320)
Browse files Browse the repository at this point in the history
* Previously, the DAG-node classes in MF were not dataclasses due to
issues with class inheritance. These have since been resolved via
hierarchy changes, so this PR updates those classes to be dataclasses.
In general, dataclasses make these easier to use, and there are upcoming
use cases where dataclasses will simplify implementation (e.g. graph
component comparison, serialization).
* A `create()` method was added to simplify many initialization use
cases while not overriding the `__init__()` generated by `dataclasses`.
* There is an update to how the `node_id` field is set - please see
`mf_dag.py`.
* Otherwise, this should be a mechanical update with no substantive
logic changes.
* There are no snapshot changes, so that should simplify review.
* Please view by commit.
  • Loading branch information
plypaul authored Jul 11, 2024
1 parent c484a81 commit 072b8d5
Show file tree
Hide file tree
Showing 70 changed files with 1,931 additions and 1,887 deletions.
44 changes: 27 additions & 17 deletions metricflow-semantics/metricflow_semantics/dag/mf_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import textwrap
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Sequence, TypeVar
from typing import Any, Generic, Optional, Sequence, Tuple, TypeVar

import jinja2
from typing_extensions import override

from metricflow_semantics.dag.dag_to_text import MetricFlowDagTextFormatter
from metricflow_semantics.dag.id_prefix import IdPrefix
from metricflow_semantics.dag.sequential_id import SequentialIdGenerator
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable
from metricflow_semantics.visitor import VisitorOutputT

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,16 +54,30 @@ def visit_node(self, node: DagNode) -> VisitorOutputT: # noqa: D102
pass


class DagNode(ABC):
DagNodeT = TypeVar("DagNodeT", bound="DagNode")


@dataclass(frozen=True)
class DagNode(MetricFlowPrettyFormattable, Generic[DagNodeT], ABC):
"""A node in a DAG. These should be immutable."""

def __init__(self, node_id: NodeId) -> None: # noqa: D107
self._node_id = node_id
parent_nodes: Tuple[DagNodeT, ...]

def __post_init__(self) -> None: # noqa: D105
# This dataclass is frozen, so ordinary attribute assignment is not allowed. Store the value returned by
# `create_unique_id()` through `object.__setattr__`; the `node_id` property reads it back by the same name.
object.__setattr__(self, "_post_init_node_id", self.create_unique_id())

@property
def node_id(self) -> NodeId:
"""ID for uniquely identifying a given node."""
return self._node_id
"""ID for uniquely identifying a given node.
Ideally, this field would have a default value. However, giving a field a default value in this class means that all
subclasses would have to provide default values for all of their fields, as fields with defaults must come after fields without.
This issue is resolved in Python 3.10 with `kw_only`, so this can be updated once this project's minimum Python
version is 3.10.
Set via `object.__setattr__` in `__post_init__` to work around limitations of frozen dataclasses.
"""
return getattr(self, "_post_init_node_id")

@property
@abstractmethod
Expand All @@ -85,14 +101,6 @@ def graphviz_label(self) -> str:
properties=self.displayed_properties,
)

@property
@abstractmethod
def parent_nodes(self) -> Sequence[DagNode]: # noqa: D102
pass

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(node_id={self.node_id})"

@classmethod
@abstractmethod
def id_prefix(cls) -> IdPrefix:
Expand All @@ -115,6 +123,11 @@ def structure_text(self, formatter: MetricFlowDagTextFormatter = MetricFlowDagTe
"""Return a text representation that shows the structure of the DAG component starting from this node."""
return formatter.dag_component_to_text(self)

@property
@override
def pretty_format(self) -> Optional[str]:
# Compact representation: only the concrete class name and the node's ID string are included.
return f"{self.__class__.__name__}(node_id={self.node_id.id_str})"


def make_graphviz_label(
title: str, properties: Sequence[DisplayedProperty], title_font_size: int = 12, property_font_size: int = 6
Expand Down Expand Up @@ -175,9 +188,6 @@ def from_id_prefix(id_prefix: IdPrefix) -> DagId: # noqa: D102
return DagId(id_str=SequentialIdGenerator.create_next_id(id_prefix).str_value)


DagNodeT = TypeVar("DagNodeT", bound=DagNode)


class MetricFlowDag(Generic[DagNodeT]):
"""Represents a directed acyclic graph. The sink nodes will have the connected components."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ def _build_dag_component_for_metric(
)

source_candidates_for_measure_nodes = tuple(
MeasureGroupByItemSourceNode(
MeasureGroupByItemSourceNode.create(
measure_reference=measure_reference,
child_metric_reference=metric_reference,
)
for measure_reference in measure_references_for_metric
)
return MetricGroupByItemResolutionNode(
return MetricGroupByItemResolutionNode.create(
metric_reference=metric_reference,
metric_input_location=metric_input_location,
parent_nodes=source_candidates_for_measure_nodes,
)
# For a derived metric, the parents are other metrics.
return MetricGroupByItemResolutionNode(
return MetricGroupByItemResolutionNode.create(
metric_reference=metric_reference,
metric_input_location=metric_input_location,
parent_nodes=tuple(
Expand All @@ -88,12 +88,12 @@ def _build_dag_component_for_query(
) -> QueryGroupByItemResolutionNode:
"""Builds a DAG component that represents the resolution flow for a query."""
if len(metric_references) == 0:
return QueryGroupByItemResolutionNode(
parent_nodes=(NoMetricsGroupByItemSourceNode(),),
return QueryGroupByItemResolutionNode.create(
parent_nodes=(NoMetricsGroupByItemSourceNode.create(),),
metrics_in_query=metric_references,
where_filter_intersection=where_filter_intersection,
)
return QueryGroupByItemResolutionNode(
return QueryGroupByItemResolutionNode.create(
parent_nodes=tuple(
self._build_dag_component_for_metric(
metric_reference=metric_reference,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Sequence, Tuple
from typing import TYPE_CHECKING, Generic, Tuple

from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.mf_dag import DagNode, NodeId
from metricflow_semantics.dag.mf_dag import DagNode
from metricflow_semantics.visitor import Visitable, VisitorOutputT

if TYPE_CHECKING:
Expand All @@ -26,14 +26,14 @@
)


class GroupByItemResolutionNode(DagNode, Visitable, ABC):
@dataclass(frozen=True)
class GroupByItemResolutionNode(DagNode["GroupByItemResolutionNode"], Visitable, ABC):
"""Base node type for nodes in a GroupByItemResolutionDag.
See GroupByItemResolutionDag for more details.
"""

def __init__(self) -> None: # noqa: D107
super().__init__(node_id=NodeId.create_unique(self.__class__.id_prefix()))
parent_nodes: Tuple[GroupByItemResolutionNode, ...]

@abstractmethod
def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
Expand All @@ -46,11 +46,6 @@ def ui_description(self) -> str:
"""A string that can be used to describe this node as a path element in the UI."""
raise NotImplementedError

@property
@abstractmethod
def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]: # noqa: D102
raise NotImplementedError

@abstractmethod
def _self_set(self) -> GroupByItemResolutionNodeSet:
"""Return a `GroupByItemResolutionNodeInclusiveAncestorSet` only containing self.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from dbt_semantic_interfaces.references import MeasureReference, MetricReference
Expand All @@ -15,23 +16,32 @@
from metricflow_semantics.visitor import VisitorOutputT


@dataclass(frozen=True)
class MeasureGroupByItemSourceNode(GroupByItemResolutionNode):
"""Outputs group-by-items for a measure."""
"""Outputs group-by-items for a measure.
def __init__(
self,
Attributes:
measure_reference: Get the group-by items for this measure.
child_metric_reference: The metric that uses this measure.
"""

measure_reference: MeasureReference
child_metric_reference: MetricReference

def __post_init__(self) -> None: # noqa: D105
# Run the parent class's post-init hook first.
super().__post_init__()
# Instances of this node type are expected to be constructed with an empty `parent_nodes` tuple.
assert len(self.parent_nodes) == 0

@staticmethod
def create( # noqa: D102
measure_reference: MeasureReference,
child_metric_reference: MetricReference,
) -> None:
"""Initializer.
Args:
measure_reference: Get the group-by items for this measure.
child_metric_reference: The metric that uses this measure.
"""
self._measure_reference = measure_reference
self._child_metric_reference = child_metric_reference
super().__init__()
) -> MeasureGroupByItemSourceNode:
return MeasureGroupByItemSourceNode(
parent_nodes=(),
measure_reference=measure_reference,
child_metric_reference=child_metric_reference,
)

@override
def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
Expand All @@ -42,11 +52,6 @@ def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> V
def description(self) -> str:
return "Output group-by-items available for this measure."

@property
@override
def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]:
return ()

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
Expand All @@ -58,23 +63,14 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]:
return tuple(super().displayed_properties) + (
DisplayedProperty(
key="measure_reference",
value=str(self._measure_reference),
value=str(self.measure_reference),
),
DisplayedProperty(
key="child_metric_reference",
value=str(self._child_metric_reference),
value=str(self.child_metric_reference),
),
)

@property
def measure_reference(self) -> MeasureReference: # noqa: D102
return self._measure_reference

@property
def child_metric_reference(self) -> MetricReference:
"""Return the metric that uses this measure."""
return self._child_metric_reference

@property
@override
def ui_description(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Optional, Sequence, Union
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import Self, override
from typing_extensions import override

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
Expand All @@ -19,27 +20,31 @@
from metricflow_semantics.visitor import VisitorOutputT


@dataclass(frozen=True)
class MetricGroupByItemResolutionNode(GroupByItemResolutionNode):
"""Outputs group-by-items relevant to a metric based on the input group-by-items."""
"""Outputs group-by-items relevant to a metric based on the input group-by-items.
def __init__(
self,
Attributes:
metric_reference: The metric that this represents.
metric_input_location: If this is an input metric for a derived metric, the location within the derived metric definition.
parent_nodes: The parent nodes of this metric.
"""

metric_reference: MetricReference
metric_input_location: Optional[InputMetricDefinitionLocation]
parent_nodes: Tuple[Union[MeasureGroupByItemSourceNode, MetricGroupByItemResolutionNode], ...]

@staticmethod
def create( # noqa: D102
metric_reference: MetricReference,
metric_input_location: Optional[InputMetricDefinitionLocation],
parent_nodes: Sequence[Union[MeasureGroupByItemSourceNode, Self]],
) -> None:
"""Initializer.
Args:
metric_reference: The metric that this represents.
metric_input_location: If this is an input metric for a derived metric, the location within the derived
metric definition.
parent_nodes: The parent nodes of this metric.
"""
self._metric_reference = metric_reference
self._metric_input_location = metric_input_location
self._parent_nodes = parent_nodes
super().__init__()
parent_nodes: Sequence[Union[MeasureGroupByItemSourceNode, MetricGroupByItemResolutionNode]],
) -> MetricGroupByItemResolutionNode:
return MetricGroupByItemResolutionNode(
metric_reference=metric_reference,
metric_input_location=metric_input_location,
parent_nodes=tuple(parent_nodes),
)

@override
def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
Expand All @@ -50,11 +55,6 @@ def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> V
def description(self) -> str:
return "Output group-by-items available for this metric."

@property
@override
def parent_nodes(self) -> Sequence[Union[MeasureGroupByItemSourceNode, Self]]:
return self._parent_nodes

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
Expand All @@ -66,26 +66,18 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]:
return tuple(super().displayed_properties) + (
DisplayedProperty(
key="metric_reference",
value=str(self._metric_reference),
value=str(self.metric_reference),
),
)

@property
def metric_reference(self) -> MetricReference: # noqa: D102
return self._metric_reference

@property
def metric_input_location(self) -> Optional[InputMetricDefinitionLocation]: # noqa: D102
return self._metric_input_location

@property
@override
def ui_description(self) -> str:
if self._metric_input_location is None:
return f"Metric({repr(self._metric_reference.element_name)})"
if self.metric_input_location is None:
return f"Metric({repr(self.metric_reference.element_name)})"
return (
f"Metric({repr(self._metric_reference.element_name)}, "
f"input_metric_index={self._metric_input_location.input_metric_list_index})"
f"Metric({repr(self.metric_reference.element_name)}, "
f"input_metric_index={self.metric_input_location.input_metric_list_index})"
)

@override
Expand Down
Loading

0 comments on commit 072b8d5

Please sign in to comment.