Skip to content

Commit

Permalink
Add eq=False to node classes (#1500)
Browse files Browse the repository at this point in the history
This sets `eq=False` on the node dataclasses. Nodes should be unique, and
because they represent recursive data structures, the dedicated equivalence
methods should be used when a structural comparison is needed. Without this,
comparing nodes or using them as keys can be slow, since the generated
`__eq__` method traverses the recursive structure.
  • Loading branch information
plypaul authored Nov 9, 2024
1 parent 647a8be commit de6ed9d
Show file tree
Hide file tree
Showing 24 changed files with 54 additions and 48 deletions.
9 changes: 7 additions & 2 deletions metricflow-semantics/metricflow_semantics/dag/mf_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ def visit_node(self, node: DagNode) -> VisitorOutputT: # noqa: D102
DagNodeT = TypeVar("DagNodeT", bound="DagNode")


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class DagNode(MetricFlowPrettyFormattable, Generic[DagNodeT], ABC):
"""A node in a DAG. These should be immutable."""
"""A node in a DAG. These should be immutable.
Since there should only be a single instance of a node with a given ID, `eq` can be set to false so that equality
operations can be done without comparing the fields. Comparing the fields can be a slow process since the
`parent_nodes` field is recursive.
"""

parent_nodes: Tuple[DagNodeT, ...]

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode")


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class DataflowPlanNode(DagNode["DataflowPlanNode"], Visitable, ABC):
"""A node in the graph representation of the dataflow.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/add_generated_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class AddGeneratedUuidColumnNode(DataflowPlanNode):
"""Adds a UUID column."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/aggregate_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class AggregateMeasuresNode(DataflowPlanNode):
"""A node that aggregates the measures by the associated group by elements.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/combine_aggregated_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class CombineAggregatedOutputsNode(DataflowPlanNode):
"""Combines metrics from different nodes into a single output."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ComputeMetricsNode(DataflowPlanNode):
"""A node that computes metrics from input measures. Dimensions / entities are passed through.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/constrain_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.nodes.aggregate_measures import DataflowPlanNode


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ConstrainTimeRangeNode(DataflowPlanNode):
"""Constrains the time range of the input data set.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/filter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class FilterElementsNode(DataflowPlanNode):
"""Only passes the listed elements.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_conversion_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinConversionEventsNode(DataflowPlanNode):
"""Builds a data set containing successful conversion events.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinOverTimeRangeNode(DataflowPlanNode):
"""A node that allows for cumulative metric computation by doing a self join across a cumulative date range.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_to_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __post_init__(self) -> None: # noqa: D105
raise RuntimeError("`join_on_entity` is required unless using CROSS JOIN.")


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinOnEntitiesNode(DataflowPlanNode):
"""A node that joins data from other nodes via the entities in the inputs.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset to convert time dimension to a custom granularity.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_to_time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinToTimeSpineNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/metric_time_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class MetricTimeDimensionTransformNode(DataflowPlanNode):
"""A node transforms the input data set so that it contains the metric time dimension and relevant measures.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class MinMaxNode(DataflowPlanNode):
"""Calculate the min and max of a single instance data set."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/order_by_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class OrderByLimitNode(DataflowPlanNode):
"""A node that re-orders the input data with a limit.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/read_sql_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from metricflow.dataset.sql_dataset import SqlDataSet


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ReadSqlSourceNode(DataflowPlanNode):
"""A source node where data from an SQL table or SQL query is read and output.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/semi_additive_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SemiAdditiveJoinNode(DataflowPlanNode):
"""A node that performs a row filter by aggregating a given non-additive dimension.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/where_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WhereConstraintNode(DataflowPlanNode):
"""Remove rows using a WHERE clause.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/window_reaggregation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WindowReaggregationNode(DataflowPlanNode):
"""A node that re-aggregates metrics using window functions.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/write_to_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WriteToResultDataTableNode(DataflowPlanNode):
"""A node where incoming data gets written to a data_table."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/write_to_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WriteToResultTableNode(DataflowPlanNode):
"""A node where incoming data gets written to a table.
Expand Down
39 changes: 20 additions & 19 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing_extensions import override


@dataclass(frozen=True, order=True)
@dataclass(frozen=True, eq=False)
class SqlExpressionNode(DagNode["SqlExpressionNode"], Visitable, ABC):
"""An SQL expression like my_table.my_column, CONCAT(a, b) or 1 + 1 that evaluates to a value."""

Expand Down Expand Up @@ -230,7 +230,7 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOu
pass


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlStringExpression(SqlExpressionNode):
"""An SQL expression in a string format, so it lacks information about the structure.
Expand Down Expand Up @@ -314,7 +314,7 @@ def as_string_expression(self) -> Optional[SqlStringExpression]:
return self


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlStringLiteralExpression(SqlExpressionNode):
"""A string literal like 'foo'. It shouldn't include delimiters as it should be added during rendering."""

Expand Down Expand Up @@ -375,7 +375,7 @@ class SqlColumnReference:
column_name: str


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlColumnReferenceExpression(SqlExpressionNode):
"""An expression that evaluates to the value of a column in one of the sources in the select query.
Expand Down Expand Up @@ -475,7 +475,7 @@ def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumn
return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name))


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlColumnAliasReferenceExpression(SqlExpressionNode):
"""An expression that evaluates to the alias of a column, but is not qualified with a table alias.
Expand Down Expand Up @@ -544,7 +544,7 @@ class SqlComparison(Enum): # noqa: D101
EQUALS = "="


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlComparisonExpression(SqlExpressionNode):
"""A comparison using >, <, <=, >=, =.
Expand Down Expand Up @@ -698,6 +698,7 @@ def from_aggregation_type(aggregation_type: AggregationType) -> SqlFunction:
assert_values_exhausted(aggregation_type)


@dataclass(frozen=True, eq=False)
class SqlFunctionExpression(SqlExpressionNode):
"""Denotes a function expression in SQL."""

Expand All @@ -723,7 +724,7 @@ def build_expression_from_aggregation_type(
return SqlAggregateFunctionExpression.from_aggregation_type(aggregation_type, sql_column_expression)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlAggregateFunctionExpression(SqlFunctionExpression):
"""An aggregate function expression like SUM(1).
Expand Down Expand Up @@ -857,7 +858,7 @@ def from_aggregation_parameters(agg_params: MeasureAggregationParameters) -> Sql
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlPercentileExpression(SqlFunctionExpression):
"""A percentile aggregation expression.
Expand Down Expand Up @@ -984,7 +985,7 @@ def suffix(self) -> str:
return " ".join(result)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlWindowFunctionExpression(SqlFunctionExpression):
"""A window function expression like SUM(foo) OVER bar.
Expand Down Expand Up @@ -1101,7 +1102,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlNullExpression(SqlExpressionNode):
"""Represents NULL."""

Expand Down Expand Up @@ -1151,7 +1152,7 @@ class SqlLogicalOperator(Enum):
OR = "OR"


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlLogicalExpression(SqlExpressionNode):
"""A logical expression like "a AND b AND c"."""

Expand Down Expand Up @@ -1203,7 +1204,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.operator == other.operator and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlIsNullExpression(SqlExpressionNode):
"""An IS NULL expression like "foo IS NULL"."""

Expand Down Expand Up @@ -1248,7 +1249,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlSubtractTimeIntervalExpression(SqlExpressionNode):
"""Represents an interval subtraction from a given timestamp.
Expand Down Expand Up @@ -1313,7 +1314,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlCastToTimestampExpression(SqlExpressionNode):
"""Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP)."""

Expand Down Expand Up @@ -1360,7 +1361,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlDateTruncExpression(SqlExpressionNode):
"""Apply a date trunc to a column like CAST('2020-01-01' AS TIMESTAMP)."""

Expand Down Expand Up @@ -1411,7 +1412,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.time_granularity == other.time_granularity and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlExtractExpression(SqlExpressionNode):
"""Extract a date part from a time expression.
Expand Down Expand Up @@ -1470,7 +1471,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.date_part == other.date_part and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlRatioComputationExpression(SqlExpressionNode):
"""Node for expressing Ratio metrics to allow for appropriate casting to float/double in each engine.
Expand Down Expand Up @@ -1535,7 +1536,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlBetweenExpression(SqlExpressionNode):
"""A BETWEEN clause like `column BETWEEN val1 AND val2`.
Expand Down Expand Up @@ -1600,7 +1601,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlGenerateUuidExpression(SqlExpressionNode):
"""Renders a SQL to generate a random UUID, which is non-deterministic."""

Expand Down
Loading

0 comments on commit de6ed9d

Please sign in to comment.