diff --git a/metricflow-semantics/metricflow_semantics/dag/mf_dag.py b/metricflow-semantics/metricflow_semantics/dag/mf_dag.py index a3a24774a2..7b975d0a28 100644 --- a/metricflow-semantics/metricflow_semantics/dag/mf_dag.py +++ b/metricflow-semantics/metricflow_semantics/dag/mf_dag.py @@ -57,9 +57,14 @@ def visit_node(self, node: DagNode) -> VisitorOutputT: # noqa: D102 DagNodeT = TypeVar("DagNodeT", bound="DagNode") -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class DagNode(MetricFlowPrettyFormattable, Generic[DagNodeT], ABC): - """A node in a DAG. These should be immutable.""" + """A node in a DAG. These should be immutable. + + Since there should only be a single instance of a node with a given ID, `eq` can be set to false so that equality + operations can be done without comparing the fields. Comparing the fields can be a slow process since the + `parent_nodes` field is recursive. + """ parent_nodes: Tuple[DagNodeT, ...] diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index b986d7b8f6..dd5f1088b9 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -44,7 +44,7 @@ NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode") -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class DataflowPlanNode(DagNode["DataflowPlanNode"], Visitable, ABC): """A node in the graph representation of the dataflow. diff --git a/metricflow/dataflow/nodes/add_generated_uuid.py b/metricflow/dataflow/nodes/add_generated_uuid.py index 6a5a1c2b9f..96df59388d 100644 --- a/metricflow/dataflow/nodes/add_generated_uuid.py +++ b/metricflow/dataflow/nodes/add_generated_uuid.py @@ -10,7 +10,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class AddGeneratedUuidColumnNode(DataflowPlanNode): """Adds a UUID column.""" diff --git a/metricflow/dataflow/nodes/aggregate_measures.py b/metricflow/dataflow/nodes/aggregate_measures.py index 9128f6a0a6..7fa8c153ce 100644 --- a/metricflow/dataflow/nodes/aggregate_measures.py +++ b/metricflow/dataflow/nodes/aggregate_measures.py @@ -10,7 +10,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class AggregateMeasuresNode(DataflowPlanNode): """A node that aggregates the measures by the associated group by elements. diff --git a/metricflow/dataflow/nodes/combine_aggregated_outputs.py b/metricflow/dataflow/nodes/combine_aggregated_outputs.py index 0f022ec8a2..c1c3ad2e3f 100644 --- a/metricflow/dataflow/nodes/combine_aggregated_outputs.py +++ b/metricflow/dataflow/nodes/combine_aggregated_outputs.py @@ -12,7 +12,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class CombineAggregatedOutputsNode(DataflowPlanNode): """Combines metrics from different nodes into a single output.""" diff --git a/metricflow/dataflow/nodes/compute_metrics.py b/metricflow/dataflow/nodes/compute_metrics.py index 9d4ad3dd92..3220d4374d 100644 --- a/metricflow/dataflow/nodes/compute_metrics.py +++ b/metricflow/dataflow/nodes/compute_metrics.py @@ -16,7 +16,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ComputeMetricsNode(DataflowPlanNode): """A node that computes metrics from input measures. Dimensions / entities are passed through. diff --git a/metricflow/dataflow/nodes/constrain_time.py b/metricflow/dataflow/nodes/constrain_time.py index 7ca0ace50b..9694039a44 100644 --- a/metricflow/dataflow/nodes/constrain_time.py +++ b/metricflow/dataflow/nodes/constrain_time.py @@ -12,7 +12,7 @@ from metricflow.dataflow.nodes.aggregate_measures import DataflowPlanNode -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ConstrainTimeRangeNode(DataflowPlanNode): """Constrains the time range of the input data set. diff --git a/metricflow/dataflow/nodes/filter_elements.py b/metricflow/dataflow/nodes/filter_elements.py index e38ef4e0c2..9605f58bdd 100644 --- a/metricflow/dataflow/nodes/filter_elements.py +++ b/metricflow/dataflow/nodes/filter_elements.py @@ -12,7 +12,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class FilterElementsNode(DataflowPlanNode): """Only passes the listed elements. diff --git a/metricflow/dataflow/nodes/join_conversion_events.py b/metricflow/dataflow/nodes/join_conversion_events.py index ed42661222..cc6530996b 100644 --- a/metricflow/dataflow/nodes/join_conversion_events.py +++ b/metricflow/dataflow/nodes/join_conversion_events.py @@ -16,7 +16,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinConversionEventsNode(DataflowPlanNode): """Builds a data set containing successful conversion events. diff --git a/metricflow/dataflow/nodes/join_over_time.py b/metricflow/dataflow/nodes/join_over_time.py index b137175472..92087cf97b 100644 --- a/metricflow/dataflow/nodes/join_over_time.py +++ b/metricflow/dataflow/nodes/join_over_time.py @@ -14,7 +14,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinOverTimeRangeNode(DataflowPlanNode): """A node that allows for cumulative metric computation by doing a self join across a cumulative date range. diff --git a/metricflow/dataflow/nodes/join_to_base.py b/metricflow/dataflow/nodes/join_to_base.py index 2a6b2c458e..50791c580b 100644 --- a/metricflow/dataflow/nodes/join_to_base.py +++ b/metricflow/dataflow/nodes/join_to_base.py @@ -43,7 +43,7 @@ def __post_init__(self) -> None: # noqa: D105 raise RuntimeError("`join_on_entity` is required unless using CROSS JOIN.") -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinOnEntitiesNode(DataflowPlanNode): """A node that joins data from other nodes via the entities in the inputs. diff --git a/metricflow/dataflow/nodes/join_to_custom_granularity.py b/metricflow/dataflow/nodes/join_to_custom_granularity.py index 6f13f6ece4..2e70b36037 100644 --- a/metricflow/dataflow/nodes/join_to_custom_granularity.py +++ b/metricflow/dataflow/nodes/join_to_custom_granularity.py @@ -12,7 +12,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinToCustomGranularityNode(DataflowPlanNode, ABC): """Join parent dataset to time spine dataset to convert time dimension to a custom granularity. diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py index 00633a0fa0..dfc0f10151 100644 --- a/metricflow/dataflow/nodes/join_to_time_spine.py +++ b/metricflow/dataflow/nodes/join_to_time_spine.py @@ -17,7 +17,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinToTimeSpineNode(DataflowPlanNode, ABC): """Join parent dataset to time spine dataset. diff --git a/metricflow/dataflow/nodes/metric_time_transform.py b/metricflow/dataflow/nodes/metric_time_transform.py index 47e5df2ffd..5687904d76 100644 --- a/metricflow/dataflow/nodes/metric_time_transform.py +++ b/metricflow/dataflow/nodes/metric_time_transform.py @@ -11,7 +11,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class MetricTimeDimensionTransformNode(DataflowPlanNode): """A node transforms the input data set so that it contains the metric time dimension and relevant measures. diff --git a/metricflow/dataflow/nodes/min_max.py b/metricflow/dataflow/nodes/min_max.py index 40fa160739..c7713185f5 100644 --- a/metricflow/dataflow/nodes/min_max.py +++ b/metricflow/dataflow/nodes/min_max.py @@ -9,7 +9,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class MinMaxNode(DataflowPlanNode): """Calculate the min and max of a single instance data set.""" diff --git a/metricflow/dataflow/nodes/order_by_limit.py b/metricflow/dataflow/nodes/order_by_limit.py index 0bb1c77b99..f7cbacdf0c 100644 --- a/metricflow/dataflow/nodes/order_by_limit.py +++ b/metricflow/dataflow/nodes/order_by_limit.py @@ -14,7 +14,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class OrderByLimitNode(DataflowPlanNode): """A node that re-orders the input data with a limit. diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index de1da2f604..57a272dffc 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -15,7 +15,7 @@ from metricflow.dataset.sql_dataset import SqlDataSet -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ReadSqlSourceNode(DataflowPlanNode): """A source node where data from an SQL table or SQL query is read and output. diff --git a/metricflow/dataflow/nodes/semi_additive_join.py b/metricflow/dataflow/nodes/semi_additive_join.py index 3eaff4d88c..1334cde336 100644 --- a/metricflow/dataflow/nodes/semi_additive_join.py +++ b/metricflow/dataflow/nodes/semi_additive_join.py @@ -13,7 +13,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SemiAdditiveJoinNode(DataflowPlanNode): """A node that performs a row filter by aggregating a given non-additive dimension. diff --git a/metricflow/dataflow/nodes/where_filter.py b/metricflow/dataflow/nodes/where_filter.py index 7b0bef6cda..1152376a1d 100644 --- a/metricflow/dataflow/nodes/where_filter.py +++ b/metricflow/dataflow/nodes/where_filter.py @@ -11,7 +11,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WhereConstraintNode(DataflowPlanNode): """Remove rows using a WHERE clause. diff --git a/metricflow/dataflow/nodes/window_reaggregation_node.py b/metricflow/dataflow/nodes/window_reaggregation_node.py index 93dbf950ba..3bbe202c9f 100644 --- a/metricflow/dataflow/nodes/window_reaggregation_node.py +++ b/metricflow/dataflow/nodes/window_reaggregation_node.py @@ -17,7 +17,7 @@ from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WindowReaggregationNode(DataflowPlanNode): """A node that re-aggregates metrics using window functions. diff --git a/metricflow/dataflow/nodes/write_to_data_table.py b/metricflow/dataflow/nodes/write_to_data_table.py index 39f6eb0fb0..66701d7399 100644 --- a/metricflow/dataflow/nodes/write_to_data_table.py +++ b/metricflow/dataflow/nodes/write_to_data_table.py @@ -12,7 +12,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WriteToResultDataTableNode(DataflowPlanNode): """A node where incoming data gets written to a data_table.""" diff --git a/metricflow/dataflow/nodes/write_to_table.py b/metricflow/dataflow/nodes/write_to_table.py index a17c4bd7a7..7a55a5724d 100644 --- a/metricflow/dataflow/nodes/write_to_table.py +++ b/metricflow/dataflow/nodes/write_to_table.py @@ -13,7 +13,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WriteToResultTableNode(DataflowPlanNode): """A node where incoming data gets written to a table. diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 775e962a8b..62d8e96874 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -22,7 +22,7 @@ from typing_extensions import override -@dataclass(frozen=True, order=True) +@dataclass(frozen=True, eq=False) class SqlExpressionNode(DagNode["SqlExpressionNode"], Visitable, ABC): """An SQL expression like my_table.my_column, CONCAT(a, b) or 1 + 1 that evaluates to a value.""" @@ -230,7 +230,7 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOu pass -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlStringExpression(SqlExpressionNode): """An SQL expression in a string format, so it lacks information about the structure. @@ -314,7 +314,7 @@ def as_string_expression(self) -> Optional[SqlStringExpression]: return self -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlStringLiteralExpression(SqlExpressionNode): """A string literal like 'foo'. It shouldn't include delimiters as it should be added during rendering.""" @@ -375,7 +375,7 @@ class SqlColumnReference: column_name: str -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlColumnReferenceExpression(SqlExpressionNode): """An expression that evaluates to the value of a column in one of the sources in the select query. @@ -475,7 +475,7 @@ def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumn return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name)) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlColumnAliasReferenceExpression(SqlExpressionNode): """An expression that evaluates to the alias of a column, but is not qualified with a table alias. @@ -544,7 +544,7 @@ class SqlComparison(Enum): # noqa: D101 EQUALS = "=" -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlComparisonExpression(SqlExpressionNode): """A comparison using >, <, <=, >=, =. @@ -698,6 +698,7 @@ def from_aggregation_type(aggregation_type: AggregationType) -> SqlFunction: assert_values_exhausted(aggregation_type) +@dataclass(frozen=True, eq=False) class SqlFunctionExpression(SqlExpressionNode): """Denotes a function expression in SQL.""" @@ -723,7 +724,7 @@ def build_expression_from_aggregation_type( return SqlAggregateFunctionExpression.from_aggregation_type(aggregation_type, sql_column_expression) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlAggregateFunctionExpression(SqlFunctionExpression): """An aggregate function expression like SUM(1). @@ -857,7 +858,7 @@ def from_aggregation_parameters(agg_params: MeasureAggregationParameters) -> Sql ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlPercentileExpression(SqlFunctionExpression): """A percentile aggregation expression. @@ -984,7 +985,7 @@ def suffix(self) -> str: return " ".join(result) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlWindowFunctionExpression(SqlFunctionExpression): """A window function expression like SUM(foo) OVER bar. @@ -1101,7 +1102,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlNullExpression(SqlExpressionNode): """Represents NULL.""" @@ -1151,7 +1152,7 @@ class SqlLogicalOperator(Enum): OR = "OR" -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlLogicalExpression(SqlExpressionNode): """A logical expression like "a AND b AND c".""" @@ -1203,7 +1204,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.operator == other.operator and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlIsNullExpression(SqlExpressionNode): """An IS NULL expression like "foo IS NULL".""" @@ -1248,7 +1249,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlSubtractTimeIntervalExpression(SqlExpressionNode): """Represents an interval subtraction from a given timestamp. @@ -1313,7 +1314,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlCastToTimestampExpression(SqlExpressionNode): """Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP).""" @@ -1360,7 +1361,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlDateTruncExpression(SqlExpressionNode): """Apply a date trunc to a column like CAST('2020-01-01' AS TIMESTAMP).""" @@ -1411,7 +1412,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.time_granularity == other.time_granularity and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlExtractExpression(SqlExpressionNode): """Extract a date part from a time expression. @@ -1470,7 +1471,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.date_part == other.date_part and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlRatioComputationExpression(SqlExpressionNode): """Node for expressing Ratio metrics to allow for appropriate casting to float/double in each engine. @@ -1535,7 +1536,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlBetweenExpression(SqlExpressionNode): """A BETWEEN clause like `column BETWEEN val1 AND val2`. @@ -1600,7 +1601,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlGenerateUuidExpression(SqlExpressionNode): """Renders a SQL to generate a random UUID, which is non-deterministic.""" diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index ea4a61756b..52a7390429 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlQueryPlanNode(DagNode["SqlQueryPlanNode"], ABC): """Modeling a SQL query plan like a data flow plan as well. @@ -105,7 +105,7 @@ class SqlOrderByDescription: # noqa: D101 desc: bool -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlSelectStatementNode(SqlQueryPlanNode): """Represents an SQL Select statement. @@ -197,7 +197,7 @@ def description(self) -> str: return self._description -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlTableNode(SqlQueryPlanNode): """An SQL table that can go in the FROM clause or the JOIN clause.""" @@ -234,7 +234,7 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlSelectQueryFromClauseNode(SqlQueryPlanNode): """An SQL select query that can go in the FROM clause. @@ -271,7 +271,7 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlCreateTableAsNode(SqlQueryPlanNode): """An SQL node representing a CREATE TABLE AS statement. @@ -343,7 +343,7 @@ def render_node(self) -> SqlQueryPlanNode: # noqa: D102 return self._render_node -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlCteNode(SqlQueryPlanNode): """Represents a single common table expression."""