From 072b8d52d320ce66e05165dad868c030227c7f62 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Thu, 11 Jul 2024 10:27:55 -0700 Subject: [PATCH] Convert DAG node classes to dataclasses (#1320) * Previously, the DAG-node classes in MF were not dataclasses due to issues with class inheritance. These have since been resolved via hierarchy changes, so this PR updates those classes to be dataclasses. In general, dataclasses make these easier to use, and there are upcoming use cases where dataclasses will simplify implementation (e.g. graph component comparison, serialization). * A `create()` method was added to simplify many initialization use cases while not overriding the one generated by `dataclasses`. * There is an update to how the `node_id` field is set - please see `mf_dag.py`. * Otherwise, this should be a mechanical update with no substantive logic changes. * There are no snapshot changes, so that should simplify review. * Please view by commit. --- .../metricflow_semantics/dag/mf_dag.py | 44 +- .../resolution_dag/dag_builder.py | 12 +- .../resolution_nodes/base_node.py | 15 +- .../resolution_nodes/measure_source_node.py | 54 +- .../metric_resolution_node.py | 66 +- .../no_metrics_query_source_node.py | 20 +- .../resolution_nodes/query_resolution_node.py | 47 +- .../query/query_resolver.py | 2 +- .../test_matching_item_for_querying.py | 2 +- .../dataflow/builder/dataflow_plan_builder.py | 74 +- metricflow/dataflow/builder/node_evaluator.py | 2 +- metricflow/dataflow/builder/source_node.py | 8 +- metricflow/dataflow/dataflow_plan.py | 21 +- .../dataflow/nodes/add_generated_uuid.py | 14 +- .../dataflow/nodes/aggregate_measures.py | 42 +- .../nodes/combine_aggregated_outputs.py | 17 +- metricflow/dataflow/nodes/compute_metrics.py | 65 +- metricflow/dataflow/nodes/constrain_time.py | 27 +- metricflow/dataflow/nodes/filter_elements.py | 63 +- .../dataflow/nodes/join_conversion_events.py | 110 ++- metricflow/dataflow/nodes/join_over_time.py | 74 +- 
metricflow/dataflow/nodes/join_to_base.py | 59 +- .../dataflow/nodes/join_to_time_spine.py | 120 ++-- .../dataflow/nodes/metric_time_transform.py | 41 +- metricflow/dataflow/nodes/min_max.py | 16 +- metricflow/dataflow/nodes/order_by_limit.py | 64 +- metricflow/dataflow/nodes/read_sql_source.py | 35 +- .../dataflow/nodes/semi_additive_join.py | 79 +-- metricflow/dataflow/nodes/where_filter.py | 57 +- .../nodes/window_reaggregation_node.py | 54 +- .../dataflow/nodes/write_to_data_table.py | 21 +- metricflow/dataflow/nodes/write_to_table.py | 42 +- .../optimizer/predicate_pushdown_optimizer.py | 4 +- .../source_scan/cm_branch_combiner.py | 6 +- .../source_scan/source_scan_optimizer.py | 4 +- metricflow/dataset/convert_semantic_model.py | 20 +- metricflow/execution/dataflow_to_execution.py | 14 +- metricflow/execution/execution_plan.py | 153 +++-- metricflow/plan_conversion/dataflow_to_sql.py | 122 ++-- .../plan_conversion/instance_converters.py | 12 +- metricflow/plan_conversion/node_processor.py | 10 +- .../sql_expression_builders.py | 6 +- .../plan_conversion/sql_join_builder.py | 64 +- metricflow/sql/optimizer/column_pruner.py | 10 +- .../optimizer/rewriting_sub_query_reducer.py | 28 +- metricflow/sql/optimizer/sub_query_reducer.py | 12 +- .../sql/optimizer/table_alias_simplifier.py | 8 +- metricflow/sql/sql_exprs.py | 650 +++++++++--------- metricflow/sql/sql_plan.py | 213 +++--- .../data_warehouse_model_validator.py | 14 +- scripts/ci_tests/metricflow_package_test.py | 2 +- .../dataflow/builder/test_node_data_set.py | 14 +- .../source_scan/test_cm_branch_combiner.py | 6 +- tests_metricflow/examples/test_node_sql.py | 6 +- tests_metricflow/execution/noop_task.py | 39 +- .../execution/test_sequential_executor.py | 16 +- tests_metricflow/execution/test_tasks.py | 11 +- .../fixtures/manifest_fixtures.py | 2 +- .../integration/test_configured_cases.py | 22 +- .../mf_logging/test_dag_to_text.py | 6 +- .../test_metric_time_dimension_to_sql.py | 4 +- 
.../test_dataflow_to_sql_plan.py | 120 ++-- .../sql/optimizer/test_column_pruner.py | 194 +++--- .../test_rewriting_sub_query_reducer.py | 272 ++++---- .../sql/optimizer/test_sub_query_reducer.py | 72 +- .../optimizer/test_table_alias_simplifier.py | 32 +- .../sql/test_engine_specific_rendering.py | 56 +- tests_metricflow/sql/test_sql_expr_render.py | 115 ++-- tests_metricflow/sql/test_sql_plan_render.py | 104 +-- .../sql_clients/test_date_time_operations.py | 8 +- 70 files changed, 1931 insertions(+), 1887 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/dag/mf_dag.py b/metricflow-semantics/metricflow_semantics/dag/mf_dag.py index b7bcb0ab20..a3a24774a2 100644 --- a/metricflow-semantics/metricflow_semantics/dag/mf_dag.py +++ b/metricflow-semantics/metricflow_semantics/dag/mf_dag.py @@ -7,13 +7,15 @@ import textwrap from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Generic, Sequence, TypeVar +from typing import Any, Generic, Optional, Sequence, Tuple, TypeVar import jinja2 +from typing_extensions import override from metricflow_semantics.dag.dag_to_text import MetricFlowDagTextFormatter from metricflow_semantics.dag.id_prefix import IdPrefix from metricflow_semantics.dag.sequential_id import SequentialIdGenerator +from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable from metricflow_semantics.visitor import VisitorOutputT logger = logging.getLogger(__name__) @@ -52,16 +54,30 @@ def visit_node(self, node: DagNode) -> VisitorOutputT: # noqa: D102 pass -class DagNode(ABC): +DagNodeT = TypeVar("DagNodeT", bound="DagNode") + + +@dataclass(frozen=True) +class DagNode(MetricFlowPrettyFormattable, Generic[DagNodeT], ABC): """A node in a DAG. These should be immutable.""" - def __init__(self, node_id: NodeId) -> None: # noqa: D107 - self._node_id = node_id + parent_nodes: Tuple[DagNodeT, ...] 
+ + def __post_init__(self) -> None: # noqa: D105 + object.__setattr__(self, "_post_init_node_id", self.create_unique_id()) @property def node_id(self) -> NodeId: - """ID for uniquely identifying a given node.""" - return self._node_id + """ID for uniquely identifying a given node. + + Ideally, this field would have a default value. However, setting a default field in this class means that all + subclasses would have to have default values for all the fields as default fields must come at the end. + This issue is resolved in Python 3.10 with `kw_only`, so this can be updated once this project's minimum Python + version is 3.10. + + Set via `__setattr__` in `__post_init__` to work around limitations of frozen dataclasses. + """ + return getattr(self, "_post_init_node_id") @property @abstractmethod @@ -85,14 +101,6 @@ def graphviz_label(self) -> str: properties=self.displayed_properties, ) - @property - @abstractmethod - def parent_nodes(self) -> Sequence[DagNode]: # noqa: D102 - pass - - def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(node_id={self.node_id})" - @classmethod @abstractmethod def id_prefix(cls) -> IdPrefix: @@ -115,6 +123,11 @@ def structure_text(self, formatter: MetricFlowDagTextFormatter = MetricFlowDagTe """Return a text representation that shows the structure of the DAG component starting from this node.""" return formatter.dag_component_to_text(self) + @property + @override + def pretty_format(self) -> Optional[str]: + return f"{self.__class__.__name__}(node_id={self.node_id.id_str})" + def make_graphviz_label( title: str, properties: Sequence[DisplayedProperty], title_font_size: int = 12, property_font_size: int = 6 @@ -175,9 +188,6 @@ def from_id_prefix(id_prefix: IdPrefix) -> DagId: # noqa: D102 return DagId(id_str=SequentialIdGenerator.create_next_id(id_prefix).str_value) -DagNodeT = TypeVar("DagNodeT", bound=DagNode) - - class MetricFlowDag(Generic[DagNodeT]): """Represents a directed acyclic graph. 
The sink nodes will have the connected components.""" diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/dag_builder.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/dag_builder.py index c161722907..db2cb72556 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/dag_builder.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/dag_builder.py @@ -56,19 +56,19 @@ def _build_dag_component_for_metric( ) source_candidates_for_measure_nodes = tuple( - MeasureGroupByItemSourceNode( + MeasureGroupByItemSourceNode.create( measure_reference=measure_reference, child_metric_reference=metric_reference, ) for measure_reference in measure_references_for_metric ) - return MetricGroupByItemResolutionNode( + return MetricGroupByItemResolutionNode.create( metric_reference=metric_reference, metric_input_location=metric_input_location, parent_nodes=source_candidates_for_measure_nodes, ) # For a derived metric, the parents are other metrics. 
- return MetricGroupByItemResolutionNode( + return MetricGroupByItemResolutionNode.create( metric_reference=metric_reference, metric_input_location=metric_input_location, parent_nodes=tuple( @@ -88,12 +88,12 @@ def _build_dag_component_for_query( ) -> QueryGroupByItemResolutionNode: """Builds a DAG component that represents the resolution flow for a query.""" if len(metric_references) == 0: - return QueryGroupByItemResolutionNode( - parent_nodes=(NoMetricsGroupByItemSourceNode(),), + return QueryGroupByItemResolutionNode.create( + parent_nodes=(NoMetricsGroupByItemSourceNode.create(),), metrics_in_query=metric_references, where_filter_intersection=where_filter_intersection, ) - return QueryGroupByItemResolutionNode( + return QueryGroupByItemResolutionNode.create( parent_nodes=tuple( self._build_dag_component_for_metric( metric_reference=metric_reference, diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py index db6f65e646..8933493e82 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py @@ -3,12 +3,12 @@ import itertools from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Sequence, Tuple +from typing import TYPE_CHECKING, Generic, Tuple from typing_extensions import override from metricflow_semantics.collection_helpers.merger import Mergeable -from metricflow_semantics.dag.mf_dag import DagNode, NodeId +from metricflow_semantics.dag.mf_dag import DagNode from metricflow_semantics.visitor import Visitable, VisitorOutputT if TYPE_CHECKING: @@ -26,14 +26,14 @@ ) -class GroupByItemResolutionNode(DagNode, Visitable, ABC): +@dataclass(frozen=True) +class 
GroupByItemResolutionNode(DagNode["GroupByItemResolutionNode"], Visitable, ABC): """Base node type for nodes in a GroupByItemResolutionDag. See GroupByItemResolutionDag for more details. """ - def __init__(self) -> None: # noqa: D107 - super().__init__(node_id=NodeId.create_unique(self.__class__.id_prefix())) + parent_nodes: Tuple[GroupByItemResolutionNode, ...] @abstractmethod def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: @@ -46,11 +46,6 @@ def ui_description(self) -> str: """A string that can be used to describe this node as a path element in the UI.""" raise NotImplementedError - @property - @abstractmethod - def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]: # noqa: D102 - raise NotImplementedError - @abstractmethod def _self_set(self) -> GroupByItemResolutionNodeSet: """Return a `GroupByItemResolutionNodeInclusiveAncestorSet` only containing self. diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py index 5a2dd593ae..e1d6263718 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from dbt_semantic_interfaces.references import MeasureReference, MetricReference @@ -15,23 +16,32 @@ from metricflow_semantics.visitor import VisitorOutputT +@dataclass(frozen=True) class MeasureGroupByItemSourceNode(GroupByItemResolutionNode): - """Outputs group-by-items for a measure.""" + """Outputs group-by-items for a measure. 
- def __init__( - self, + Attributes: + measure_reference: Get the group-by items for this measure. + child_metric_reference: The metric that uses this measure. + """ + + measure_reference: MeasureReference + child_metric_reference: MetricReference + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 0 + + @staticmethod + def create( # noqa: D102 measure_reference: MeasureReference, child_metric_reference: MetricReference, - ) -> None: - """Initializer. - - Args: - measure_reference: Get the group-by items for this measure. - child_metric_reference: The metric that uses this measure. - """ - self._measure_reference = measure_reference - self._child_metric_reference = child_metric_reference - super().__init__() + ) -> MeasureGroupByItemSourceNode: + return MeasureGroupByItemSourceNode( + parent_nodes=(), + measure_reference=measure_reference, + child_metric_reference=child_metric_reference, + ) @override def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: @@ -42,11 +52,6 @@ def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> V def description(self) -> str: return "Output group-by-items available for this measure." 
- @property - @override - def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]: - return () - @classmethod @override def id_prefix(cls) -> IdPrefix: @@ -58,23 +63,14 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: return tuple(super().displayed_properties) + ( DisplayedProperty( key="measure_reference", - value=str(self._measure_reference), + value=str(self.measure_reference), ), DisplayedProperty( key="child_metric_reference", - value=str(self._child_metric_reference), + value=str(self.child_metric_reference), ), ) - @property - def measure_reference(self) -> MeasureReference: # noqa: D102 - return self._measure_reference - - @property - def child_metric_reference(self) -> MetricReference: - """Return the metric that uses this measure.""" - return self._child_metric_reference - @property @override def ui_description(self) -> str: diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py index 4c3bd9bba5..9825223bd2 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Optional, Sequence, Union +from dataclasses import dataclass +from typing import Optional, Sequence, Tuple, Union from dbt_semantic_interfaces.references import MetricReference -from typing_extensions import Self, override +from typing_extensions import override from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty @@ -19,27 +20,31 @@ from metricflow_semantics.visitor import VisitorOutputT +@dataclass(frozen=True) class 
MetricGroupByItemResolutionNode(GroupByItemResolutionNode): - """Outputs group-by-items relevant to a metric based on the input group-by-items.""" + """Outputs group-by-items relevant to a metric based on the input group-by-items. - def __init__( - self, + Attributes: + metric_reference: The metric that this represents. + metric_input_location: If this is an input metric for a derived metric, the location within the derived metric definition. + parent_nodes: The parent nodes of this metric. + """ + + metric_reference: MetricReference + metric_input_location: Optional[InputMetricDefinitionLocation] + parent_nodes: Tuple[Union[MeasureGroupByItemSourceNode, MetricGroupByItemResolutionNode], ...] + + @staticmethod + def create( # noqa: D102 metric_reference: MetricReference, metric_input_location: Optional[InputMetricDefinitionLocation], - parent_nodes: Sequence[Union[MeasureGroupByItemSourceNode, Self]], - ) -> None: - """Initializer. - - Args: - metric_reference: The metric that this represents. - metric_input_location: If this is an input metric for a derived metric, the location within the derived - metric definition. - parent_nodes: The parent nodes of this metric. - """ - self._metric_reference = metric_reference - self._metric_input_location = metric_input_location - self._parent_nodes = parent_nodes - super().__init__() + parent_nodes: Sequence[Union[MeasureGroupByItemSourceNode, MetricGroupByItemResolutionNode]], + ) -> MetricGroupByItemResolutionNode: + return MetricGroupByItemResolutionNode( + metric_reference=metric_reference, + metric_input_location=metric_input_location, + parent_nodes=tuple(parent_nodes), + ) @override def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: @@ -50,11 +55,6 @@ def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> V def description(self) -> str: return "Output group-by-items available for this metric." 
- @property - @override - def parent_nodes(self) -> Sequence[Union[MeasureGroupByItemSourceNode, Self]]: - return self._parent_nodes - @classmethod @override def id_prefix(cls) -> IdPrefix: @@ -66,26 +66,18 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: return tuple(super().displayed_properties) + ( DisplayedProperty( key="metric_reference", - value=str(self._metric_reference), + value=str(self.metric_reference), ), ) - @property - def metric_reference(self) -> MetricReference: # noqa: D102 - return self._metric_reference - - @property - def metric_input_location(self) -> Optional[InputMetricDefinitionLocation]: # noqa: D102 - return self._metric_input_location - @property @override def ui_description(self) -> str: - if self._metric_input_location is None: - return f"Metric({repr(self._metric_reference.element_name)})" + if self.metric_input_location is None: + return f"Metric({repr(self.metric_reference.element_name)})" return ( - f"Metric({repr(self._metric_reference.element_name)}, " - f"input_metric_index={self._metric_input_location.input_metric_list_index})" + f"Metric({repr(self.metric_reference.element_name)}, " + f"input_metric_index={self.metric_input_location.input_metric_list_index})" ) @override diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py index 015f592b1d..bd55dfc379 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from dataclasses import dataclass from typing_extensions import override @@ -10,17 +10,20 @@ 
GroupByItemResolutionNodeSet, GroupByItemResolutionNodeVisitor, ) -from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import ( - MetricGroupByItemResolutionNode, -) from metricflow_semantics.visitor import VisitorOutputT +@dataclass(frozen=True) class NoMetricsGroupByItemSourceNode(GroupByItemResolutionNode): """Outputs group-by-items that can be queried without any metrics.""" - def __init__(self) -> None: # noqa: D107 - super().__init__() + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 0 + + @staticmethod + def create() -> NoMetricsGroupByItemSourceNode: # noqa: D102 + return NoMetricsGroupByItemSourceNode(parent_nodes=()) @override def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: @@ -31,11 +34,6 @@ def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> V def description(self) -> str: return "Output the available group-by-items for a query without any metrics." 
- @property - @override - def parent_nodes(self) -> Sequence[MetricGroupByItemResolutionNode]: - return () - @classmethod @override def id_prefix(cls) -> IdPrefix: diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py index f4b8593608..8525342707 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import List, Sequence, Union +from dataclasses import dataclass +from typing import List, Sequence, Tuple, Union from dbt_semantic_interfaces.protocols import WhereFilterIntersection from dbt_semantic_interfaces.references import MetricReference @@ -22,19 +23,31 @@ from metricflow_semantics.visitor import VisitorOutputT +@dataclass(frozen=True) class QueryGroupByItemResolutionNode(GroupByItemResolutionNode): - """Output the group-by-items relevant to the query and based on the inputs.""" + """Output the group-by-items relevant to the query and based on the inputs. - def __init__( # noqa: D107 - self, + Attributes: + parent_nodes: The parent nodes of this query. + metrics_in_query: The metrics that are queried in this query. + where_filter_intersection: The intersection of where filters. + """ + + parent_nodes: Tuple[Union[MetricGroupByItemResolutionNode, NoMetricsGroupByItemSourceNode], ...] + metrics_in_query: Tuple[MetricReference, ...] 
+ where_filter_intersection: WhereFilterIntersection + + @staticmethod + def create( # noqa: D102 parent_nodes: Sequence[Union[MetricGroupByItemResolutionNode, NoMetricsGroupByItemSourceNode]], metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - ) -> None: - self._parent_nodes = tuple(parent_nodes) - self._metrics_in_query = tuple(metrics_in_query) - self._where_filter_intersection = where_filter_intersection - super().__init__() + ) -> QueryGroupByItemResolutionNode: + return QueryGroupByItemResolutionNode( + parent_nodes=tuple(parent_nodes), + metrics_in_query=tuple(metrics_in_query), + where_filter_intersection=where_filter_intersection, + ) @override def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: @@ -45,21 +58,11 @@ def accept(self, visitor: GroupByItemResolutionNodeVisitor[VisitorOutputT]) -> V def description(self) -> str: return "Output the group-by items for query." - @property - @override - def parent_nodes(self) -> Sequence[Union[MetricGroupByItemResolutionNode, NoMetricsGroupByItemSourceNode]]: - return self._parent_nodes - @classmethod @override def id_prefix(cls) -> IdPrefix: return StaticIdPrefix.QUERY_GROUP_BY_ITEM_RESOLUTION_NODE - @property - def metrics_in_query(self) -> Sequence[MetricReference]: - """Return the metrics that are queried in this query.""" - return self._metrics_in_query - @property @override def displayed_properties(self) -> List[DisplayedProperty]: @@ -85,14 +88,10 @@ def displayed_properties(self) -> List[DisplayedProperty]: return properties - @property - def where_filter_intersection(self) -> WhereFilterIntersection: # noqa: D102 - return self._where_filter_intersection - @property @override def ui_description(self) -> str: - return f"Query({repr([metric_reference.element_name for metric_reference in self._metrics_in_query])})" + return f"Query({repr([metric_reference.element_name for metric_reference in self.metrics_in_query])})" 
@override def _self_set(self) -> GroupByItemResolutionNodeSet: diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index a1517b5462..d91ed9c17b 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -381,7 +381,7 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met # Define a resolution path for issues where the input is considered to be the whole query. query_resolution_path = MetricFlowQueryResolutionPath.from_path_item( - QueryGroupByItemResolutionNode( + QueryGroupByItemResolutionNode.create( parent_nodes=(), metrics_in_query=tuple(metric_input.spec_pattern.metric_reference for metric_input in metric_inputs), where_filter_intersection=query_level_filter_input.where_filter_intersection, diff --git a/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py b/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py index 41c0cf6afc..7b9aa5a429 100644 --- a/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py +++ b/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py @@ -131,7 +131,7 @@ def test_missing_parent_for_metric( or measures). However, in the event of a validation gap upstream, we sometimes encounter inscrutable errors caused by missing parent nodes for these input types, so we add a more informative error and test for it here. 
""" - metric_node = MetricGroupByItemResolutionNode( + metric_node = MetricGroupByItemResolutionNode.create( metric_reference=MetricReference(element_name="bad_metric"), metric_input_location=None, parent_nodes=tuple() ) resolution_dag = GroupByItemResolutionDag(sink_node=metric_node) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 096d63edbf..7cea91608e 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -301,7 +301,7 @@ def _build_aggregated_conversion_node( # Build unaggregated conversions source node # Generate UUID column for conversion source to uniquely identify each row - unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode( + unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode.create( parent_node=conversion_measure_recipe.source_node ) @@ -333,10 +333,10 @@ def _build_aggregated_conversion_node( # Build the unaggregated base measure node for computing conversions unaggregated_base_measure_node = base_measure_recipe.source_node if base_measure_recipe.join_targets: - unaggregated_base_measure_node = JoinOnEntitiesNode( + unaggregated_base_measure_node = JoinOnEntitiesNode.create( left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets ) - filtered_unaggregated_base_node = FilterElementsNode( + filtered_unaggregated_base_node = FilterElementsNode.create( parent_node=unaggregated_base_measure_node, include_specs=group_specs_by_type(required_local_specs) .merge(base_required_linkable_specs.as_spec_set) @@ -347,7 +347,7 @@ def _build_aggregated_conversion_node( # The conversion events are joined by the base events which are already time constrained. However, this could # be still be constrained, where we adjust the time range to the window size similar to cumulative, but # adjusted in the opposite direction. 
- join_conversion_node = JoinConversionEventsNode( + join_conversion_node = JoinConversionEventsNode.create( base_node=filtered_unaggregated_base_node, base_time_dimension_spec=base_time_dimension_spec, conversion_node=unaggregated_conversion_measure_node, @@ -377,7 +377,9 @@ def _build_aggregated_conversion_node( ) # Combine the aggregated opportunities and conversion data sets - return CombineAggregatedOutputsNode(parent_nodes=(aggregated_base_measure_node, aggregated_conversions_node)) + return CombineAggregatedOutputsNode.create( + parent_nodes=(aggregated_base_measure_node, aggregated_conversions_node) + ) def _build_conversion_metric_output_node( self, @@ -468,7 +470,7 @@ def _build_cumulative_metric_output_node( predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) - return WindowReaggregationNode( + return WindowReaggregationNode.create( parent_node=compute_metrics_node, metric_spec=metric_spec, order_by_spec=default_metric_time, @@ -609,9 +611,11 @@ def _build_derived_metric_output_node( ) parent_node = ( - parent_nodes[0] if len(parent_nodes) == 1 else CombineAggregatedOutputsNode(parent_nodes=parent_nodes) + parent_nodes[0] + if len(parent_nodes) == 1 + else CombineAggregatedOutputsNode.create(parent_nodes=parent_nodes) ) - output_node: DataflowPlanNode = ComputeMetricsNode( + output_node: DataflowPlanNode = ComputeMetricsNode.create( parent_node=parent_node, metric_specs=[metric_spec], for_group_by_source_node=for_group_by_source_node, @@ -626,7 +630,7 @@ def _build_derived_metric_output_node( assert ( queried_agg_time_dimension_specs ), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension." 
- output_node = JoinToTimeSpineNode( + output_node = JoinToTimeSpineNode.create( parent_node=output_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, @@ -637,9 +641,9 @@ def _build_derived_metric_output_node( ) if len(metric_spec.filter_specs) > 0: - output_node = WhereConstraintNode(parent_node=output_node, where_specs=metric_spec.filter_specs) + output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=metric_spec.filter_specs) if not extraneous_linkable_specs.is_subset_of(queried_linkable_specs): - output_node = FilterElementsNode( + output_node = FilterElementsNode.create( parent_node=output_node, include_specs=InstanceSpecSet(metric_specs=(metric_spec,)).merge( queried_linkable_specs.as_spec_set @@ -734,7 +738,7 @@ def _build_metrics_output_node( if len(output_nodes) == 1: return output_nodes[0] - return CombineAggregatedOutputsNode(parent_nodes=output_nodes) + return CombineAggregatedOutputsNode.create(parent_nodes=output_nodes) def build_plan_for_distinct_values( self, query_spec: MetricFlowQuerySpec, optimizations: FrozenSet[DataflowPlanOptimization] = frozenset() @@ -779,21 +783,21 @@ def _build_plan_for_distinct_values( output_node = dataflow_recipe.source_node if dataflow_recipe.join_targets: - output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets) + output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets) if len(query_level_filter_specs) > 0: - output_node = WhereConstraintNode(parent_node=output_node, where_specs=query_level_filter_specs) + output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs) if query_spec.time_range_constraint: - output_node = ConstrainTimeRangeNode( + output_node = ConstrainTimeRangeNode.create( parent_node=output_node, 
time_range_constraint=query_spec.time_range_constraint ) - output_node = FilterElementsNode( + output_node = FilterElementsNode.create( parent_node=output_node, include_specs=query_spec.linkable_specs.as_spec_set, distinct=True ) if query_spec.min_max_only: - output_node = MinMaxNode(parent_node=output_node) + output_node = MinMaxNode.create(parent_node=output_node) sink_node = self.build_sink_node( parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit @@ -814,20 +818,20 @@ def build_sink_node( pre_result_node: Optional[DataflowPlanNode] = None if order_by_specs or limit: - pre_result_node = OrderByLimitNode( + pre_result_node = OrderByLimitNode.create( order_by_specs=list(order_by_specs), limit=limit, parent_node=parent_node ) if output_selection_specs: - pre_result_node = FilterElementsNode( + pre_result_node = FilterElementsNode.create( parent_node=pre_result_node or parent_node, include_specs=output_selection_specs ) write_result_node: DataflowPlanNode if not output_sql_table: - write_result_node = WriteToResultDataTableNode(parent_node=pre_result_node or parent_node) + write_result_node = WriteToResultDataTableNode.create(parent_node=pre_result_node or parent_node) else: - write_result_node = WriteToResultTableNode( + write_result_node = WriteToResultTableNode.create( parent_node=pre_result_node or parent_node, output_sql_table=output_sql_table ) @@ -1139,7 +1143,7 @@ def build_computed_metrics_node( for_group_by_source_node: bool = False, ) -> ComputeMetricsNode: """Builds a ComputeMetricsNode from aggregated measures.""" - return ComputeMetricsNode( + return ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], for_group_by_source_node=for_group_by_source_node, @@ -1449,7 +1453,7 @@ def _build_aggregated_measure_from_measure_source_node( # Otherwise, the measure will be aggregated over all time. 
time_range_node: Optional[JoinOverTimeRangeNode] = None if cumulative and queried_agg_time_dimension_specs: - time_range_node = JoinOverTimeRangeNode( + time_range_node = JoinOverTimeRangeNode.create( parent_node=measure_recipe.source_node, queried_agg_time_dimension_specs=tuple(queried_agg_time_dimension_specs), window=cumulative_window, @@ -1476,7 +1480,7 @@ def _build_aggregated_measure_from_measure_source_node( ) # This also uses the original time range constraint due to the application of the time window intervals # in join rendering - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=time_range_node or measure_recipe.source_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, @@ -1487,7 +1491,7 @@ def _build_aggregated_measure_from_measure_source_node( ) # Only get the required measure and the local linkable instances so that aggregations work correctly. 
- filtered_measure_source_node = FilterElementsNode( + filtered_measure_source_node = FilterElementsNode.create( parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge( group_specs_by_type(measure_recipe.required_local_linkable_specs), @@ -1497,7 +1501,7 @@ def _build_aggregated_measure_from_measure_source_node( join_targets = measure_recipe.join_targets unaggregated_measure_node: DataflowPlanNode if len(join_targets) > 0: - filtered_measures_with_joined_elements = JoinOnEntitiesNode( + filtered_measures_with_joined_elements = JoinOnEntitiesNode.create( left_node=filtered_measure_source_node, join_targets=join_targets, ) @@ -1506,7 +1510,7 @@ def _build_aggregated_measure_from_measure_source_node( required_linkable_specs.as_spec_set, ) - after_join_filtered_node = FilterElementsNode( + after_join_filtered_node = FilterElementsNode.create( parent_node=filtered_measures_with_joined_elements, include_specs=specs_to_keep_after_join ) unaggregated_measure_node = after_join_filtered_node @@ -1524,14 +1528,14 @@ def _build_aggregated_measure_from_measure_source_node( assert ( queried_linkable_specs.contains_metric_time ), "Using time constraints currently requires querying with metric_time." 
- cumulative_metric_constrained_node = ConstrainTimeRangeNode( + cumulative_metric_constrained_node = ConstrainTimeRangeNode.create( unaggregated_measure_node, predicate_pushdown_state.time_range_constraint ) pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node if len(metric_input_measure_spec.filter_specs) > 0: # Apply where constraint on the node - pre_aggregate_node = WhereConstraintNode( + pre_aggregate_node = WhereConstraintNode.create( parent_node=pre_aggregate_node, where_specs=metric_input_measure_spec.filter_specs, ) @@ -1550,7 +1554,7 @@ def _build_aggregated_measure_from_measure_source_node( window_groupings = tuple( LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings ) - pre_aggregate_node = SemiAdditiveJoinNode( + pre_aggregate_node = SemiAdditiveJoinNode.create( parent_node=pre_aggregate_node, entity_specs=window_groupings, time_dimension_spec=time_dimension_spec, @@ -1564,12 +1568,12 @@ def _build_aggregated_measure_from_measure_source_node( # show up in the final result. # # e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results. - pre_aggregate_node = FilterElementsNode( + pre_aggregate_node = FilterElementsNode.create( parent_node=pre_aggregate_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set), ) - aggregate_measures_node = AggregateMeasuresNode( + aggregate_measures_node = AggregateMeasuresNode.create( parent_node=pre_aggregate_node, metric_input_measure_specs=(metric_input_measure_spec,), ) @@ -1583,7 +1587,7 @@ def _build_aggregated_measure_from_measure_source_node( f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if " f"there's a new use case." 
) - output_node: DataflowPlanNode = JoinToTimeSpineNode( + output_node: DataflowPlanNode = JoinToTimeSpineNode.create( parent_node=aggregate_measures_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, @@ -1602,14 +1606,14 @@ def _build_aggregated_measure_from_measure_source_node( if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple)) ] if len(queried_filter_specs) > 0: - output_node = WhereConstraintNode( + output_node = WhereConstraintNode.create( parent_node=output_node, where_specs=queried_filter_specs, always_apply=True ) # TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time. # To fix when enabling time range constraints for agg_time_dimension. if queried_agg_time_dimension_specs and predicate_pushdown_state.time_range_constraint is not None: - output_node = ConstrainTimeRangeNode( + output_node = ConstrainTimeRangeNode.create( parent_node=output_node, time_range_constraint=predicate_pushdown_state.time_range_constraint ) return output_node diff --git a/metricflow/dataflow/builder/node_evaluator.py b/metricflow/dataflow/builder/node_evaluator.py index 262d26b854..971425eee6 100644 --- a/metricflow/dataflow/builder/node_evaluator.py +++ b/metricflow/dataflow/builder/node_evaluator.py @@ -122,7 +122,7 @@ def join_description(self) -> JoinDescription: ] ) - filtered_node_to_join = FilterElementsNode( + filtered_node_to_join = FilterElementsNode.create( parent_node=self.node_to_join, include_specs=group_specs_by_type(include_specs) ) diff --git a/metricflow/dataflow/builder/source_node.py b/metricflow/dataflow/builder/source_node.py index 7b23d58cbf..0054f57bca 100644 --- a/metricflow/dataflow/builder/source_node.py +++ b/metricflow/dataflow/builder/source_node.py @@ -59,8 +59,8 @@ def __init__( # noqa: D107 time_spine_source = 
TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest) time_spine_data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source) time_dim_reference = TimeDimensionReference(element_name=time_spine_source.time_column_name) - self._time_spine_source_node = MetricTimeDimensionTransformNode( - parent_node=ReadSqlSourceNode(data_set=time_spine_data_set), + self._time_spine_source_node = MetricTimeDimensionTransformNode.create( + parent_node=ReadSqlSourceNode.create(data_set=time_spine_data_set), aggregation_time_dimension_reference=time_dim_reference, ) self._query_parser = MetricFlowQueryParser(semantic_manifest_lookup) @@ -71,7 +71,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So source_nodes_for_metric_queries: List[DataflowPlanNode] = [] for data_set in data_sets: - read_node = ReadSqlSourceNode(data_set) + read_node = ReadSqlSourceNode.create(data_set) group_by_item_source_nodes.append(read_node) agg_time_dim_to_measures_grouper = ( self._semantic_manifest_lookup.semantic_model_lookup.get_aggregation_time_dimensions_with_measures( @@ -86,7 +86,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So else: # Splits the measures by distinct aggregate time dimension. 
for time_dimension_reference in time_dimension_references: - metric_time_transform_node = MetricTimeDimensionTransformNode( + metric_time_transform_node = MetricTimeDimensionTransformNode.create( parent_node=read_node, aggregation_time_dimension_reference=time_dimension_reference, ) diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 6944a73132..ba8c0c0a8e 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -5,11 +5,12 @@ import logging import typing from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar import more_itertools from metricflow_semantics.dag.id_prefix import StaticIdPrefix -from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId +from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag from metricflow_semantics.visitor import Visitable, VisitorOutputT if typing.TYPE_CHECKING: @@ -42,28 +43,14 @@ NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode") -class DataflowPlanNode(DagNode, Visitable, ABC): +@dataclass(frozen=True) +class DataflowPlanNode(DagNode["DataflowPlanNode"], Visitable, ABC): """A node in the graph representation of the dataflow. Each node in the graph performs an operation from the data that comes from the parent nodes, and the result is passed to the child nodes. The flow of data starts from source nodes, and ends at sink nodes. """ - def __init__(self, node_id: NodeId, parent_nodes: Sequence[DataflowPlanNode]) -> None: - """Constructor. - - Args: - node_id: the ID for the node - parent_nodes: data comes from the parent nodes. 
- """ - self._parent_nodes = tuple(parent_nodes) - super().__init__(node_id=node_id) - - @property - def parent_nodes(self) -> Sequence[DataflowPlanNode]: - """Return the nodes where data for this node comes from.""" - return self._parent_nodes - @property def _input_semantic_model(self) -> Optional[SemanticModelReference]: """Return the semantic model serving as direct input for this node, if one exists.""" diff --git a/metricflow/dataflow/nodes/add_generated_uuid.py b/metricflow/dataflow/nodes/add_generated_uuid.py index c3832819ab..6a5a1c2b9f 100644 --- a/metricflow/dataflow/nodes/add_generated_uuid.py +++ b/metricflow/dataflow/nodes/add_generated_uuid.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -9,11 +10,17 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class AddGeneratedUuidColumnNode(DataflowPlanNode): """Adds a UUID column.""" - def __init__(self, parent_node: DataflowPlanNode) -> None: # noqa: D107 - super().__init__(node_id=self.create_unique_id(), parent_nodes=[parent_node]) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create(parent_node: DataflowPlanNode) -> AddGeneratedUuidColumnNode: # noqa: D102 + return AddGeneratedUuidColumnNode(parent_nodes=(parent_node,)) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -28,7 +35,6 @@ def description(self) -> str: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - assert len(self.parent_nodes) == 1 return self.parent_nodes[0] @property @@ -42,4 +48,4 @@ def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] ) -> AddGeneratedUuidColumnNode: assert len(new_parent_nodes) == 1 - return AddGeneratedUuidColumnNode(parent_node=new_parent_nodes[0]) 
+ return AddGeneratedUuidColumnNode(parent_nodes=(new_parent_nodes[0],)) diff --git a/metricflow/dataflow/nodes/aggregate_measures.py b/metricflow/dataflow/nodes/aggregate_measures.py index 619986a093..757ba6ec8b 100644 --- a/metricflow/dataflow/nodes/aggregate_measures.py +++ b/metricflow/dataflow/nodes/aggregate_measures.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence, Tuple from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -9,6 +10,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class AggregateMeasuresNode(DataflowPlanNode): """A node that aggregates the measures by the associated group by elements. @@ -16,23 +18,25 @@ class AggregateMeasuresNode(DataflowPlanNode): resulting from an operation on this node must apply the alias and transform the measure instances accordingly, otherwise this join could produce a query with two identically named measure columns with, e.g., different constraints applied to the measure. + + The input measure specs are required for downstream nodes to be aware of any input measures with + user-provided aliases, such as we might encounter with constrained and unconstrained versions of the + same input measure. """ - def __init__( - self, - parent_node: DataflowPlanNode, - metric_input_measure_specs: Sequence[MetricInputMeasureSpec], - ) -> None: - """Initializer for AggregateMeasuresNode. + metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...] - The input measure specs are required for downstream nodes to be aware of any input measures with - user-provided aliases, such as we might encounter with constrained and unconstrained versions of the - same input measure. 
- """ - self._parent_node = parent_node - self._metric_input_measure_specs = tuple(metric_input_measure_specs) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 - super().__init__(node_id=self.create_unique_id(), parent_nodes=[self._parent_node]) + @staticmethod + def create( # noqa: D102 + parent_node: DataflowPlanNode, metric_input_measure_specs: Sequence[MetricInputMeasureSpec] + ) -> AggregateMeasuresNode: + return AggregateMeasuresNode( + parent_nodes=(parent_node,), metric_input_measure_specs=tuple(metric_input_measure_specs) + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -47,15 +51,7 @@ def description(self) -> str: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node - - @property - def metric_input_measure_specs(self) -> Tuple[MetricInputMeasureSpec, ...]: - """Iterable of specs for measure inputs to downstream metrics. - - Used for assigning aliases to output columns produced by aggregated measures. 
- """ - return self._metric_input_measure_specs + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( @@ -66,6 +62,6 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> AggregateMeasuresNode: # noqa: D102 assert len(new_parent_nodes) == 1 return AggregateMeasuresNode( - parent_node=new_parent_nodes[0], + parent_nodes=tuple(new_parent_nodes), metric_input_measure_specs=self.metric_input_measure_specs, ) diff --git a/metricflow/dataflow/nodes/combine_aggregated_outputs.py b/metricflow/dataflow/nodes/combine_aggregated_outputs.py index 437eaccc90..0f022ec8a2 100644 --- a/metricflow/dataflow/nodes/combine_aggregated_outputs.py +++ b/metricflow/dataflow/nodes/combine_aggregated_outputs.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -11,19 +12,21 @@ ) +@dataclass(frozen=True) class CombineAggregatedOutputsNode(DataflowPlanNode): """Combines metrics from different nodes into a single output.""" - def __init__( # noqa: D107 - self, - parent_nodes: Sequence[DataflowPlanNode], - ) -> None: - num_parents = len(parent_nodes) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + num_parents = len(self.parent_nodes) assert num_parents > 1, ( "The CombineAggregatedOutputsNode is intended to merge the output datasets from 2 or more nodes, but this " f"node is being initialized with with only {num_parents} parent(s)." 
) - super().__init__(node_id=self.create_unique_id(), parent_nodes=parent_nodes) + + @staticmethod + def create(parent_nodes: Sequence[DataflowPlanNode]) -> CombineAggregatedOutputsNode: # noqa: D102 + return CombineAggregatedOutputsNode(parent_nodes=tuple(parent_nodes)) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -42,4 +45,4 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] ) -> CombineAggregatedOutputsNode: - return CombineAggregatedOutputsNode(parent_nodes=new_parent_nodes) + return CombineAggregatedOutputsNode(parent_nodes=tuple(new_parent_nodes)) diff --git a/metricflow/dataflow/nodes/compute_metrics.py b/metricflow/dataflow/nodes/compute_metrics.py index 8461021445..4e079a27d6 100644 --- a/metricflow/dataflow/nodes/compute_metrics.py +++ b/metricflow/dataflow/nodes/compute_metrics.py @@ -1,11 +1,13 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence, Set, Tuple from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, MetricSpec from metricflow_semantics.visitor import VisitorOutputT +from typing_extensions import override from metricflow.dataflow.dataflow_plan import ( DataflowPlanNode, @@ -13,43 +15,41 @@ ) +@dataclass(frozen=True) class ComputeMetricsNode(DataflowPlanNode): - """A node that computes metrics from input measures. Dimensions / entities are passed through.""" + """A node that computes metrics from input measures. Dimensions / entities are passed through. - def __init__( - self, + Attributes: + metric_specs: The specs for the metrics that this should compute. + for_group_by_source_node: Whether the node is part of a dataflow plan used for a group by source node. + """ + + metric_specs: Tuple[MetricSpec, ...] 
+ for_group_by_source_node: bool + _aggregated_to_elements: Tuple[LinkableInstanceSpec, ...] + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, metric_specs: Sequence[MetricSpec], aggregated_to_elements: Set[LinkableInstanceSpec], for_group_by_source_node: bool = False, - ) -> None: - """Constructor. - - Args: - parent_node: Node where data is coming from. - metric_specs: The specs for the metrics that this should compute. - for_group_by_source_node: Whether the node is part of a dataflow plan used for a group by source node. - """ - self._parent_node = parent_node - self._metric_specs = tuple(metric_specs) - self._for_group_by_source_node = for_group_by_source_node - self._aggregated_to_elements = aggregated_to_elements - super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,)) + ) -> ComputeMetricsNode: + return ComputeMetricsNode( + parent_nodes=(parent_node,), + metric_specs=tuple(metric_specs), + _aggregated_to_elements=tuple(aggregated_to_elements), + for_group_by_source_node=for_group_by_source_node, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_COMPUTE_METRICS_ID_PREFIX - @property - def for_group_by_source_node(self) -> bool: - """Whether or not this node is part of a dataflow plan used for a group by source node.""" - return self._for_group_by_source_node - - @property - def metric_specs(self) -> Sequence[MetricSpec]: - """The metric instances that this node is supposed to compute and should have in the output.""" - return self._metric_specs - def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_compute_metrics_node(self) @@ -60,7 +60,7 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 
displayed_properties = tuple(super().displayed_properties) + tuple( - DisplayedProperty("metric_spec", metric_spec) for metric_spec in self._metric_specs + DisplayedProperty("metric_spec", metric_spec) for metric_spec in self.metric_specs ) if self.for_group_by_source_node: displayed_properties += (DisplayedProperty("for_group_by_source_node", self.for_group_by_source_node),) @@ -68,7 +68,7 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 if not isinstance(other_node, self.__class__): @@ -100,13 +100,14 @@ def can_combine(self, other_node: ComputeMetricsNode) -> Tuple[bool, str]: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> ComputeMetricsNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return ComputeMetricsNode( + return ComputeMetricsNode.create( parent_node=new_parent_nodes[0], metric_specs=self.metric_specs, for_group_by_source_node=self.for_group_by_source_node, - aggregated_to_elements=self._aggregated_to_elements, + aggregated_to_elements=self.aggregated_to_elements, ) @property - def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]: # noqa: D102 - return self._aggregated_to_elements + @override + def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]: + return set(self._aggregated_to_elements) diff --git a/metricflow/dataflow/nodes/constrain_time.py b/metricflow/dataflow/nodes/constrain_time.py index 2fd840212f..7ca0ace50b 100644 --- a/metricflow/dataflow/nodes/constrain_time.py +++ b/metricflow/dataflow/nodes/constrain_time.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -11,6 +12,7 @@ from 
metricflow.dataflow.nodes.aggregate_measures import DataflowPlanNode +@dataclass(frozen=True) class ConstrainTimeRangeNode(DataflowPlanNode): """Constrains the time range of the input data set. @@ -18,13 +20,21 @@ class ConstrainTimeRangeNode(DataflowPlanNode): includes sales for a specific range of dates. """ - def __init__( # noqa: D107 - self, + time_range_constraint: TimeRangeConstraint + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, time_range_constraint: TimeRangeConstraint, - ) -> None: - self._time_range_constraint = time_range_constraint - super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) + ) -> ConstrainTimeRangeNode: + return ConstrainTimeRangeNode( + parent_nodes=(parent_node,), + time_range_constraint=time_range_constraint, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -40,13 +50,8 @@ def description(self) -> str: # noqa: D102 f"{self.time_range_constraint.end_time.isoformat()}]" ) - @property - def time_range_constraint(self) -> TimeRangeConstraint: # noqa: D102 - return self._time_range_constraint - @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - assert len(self.parent_nodes) == 1 return self.parent_nodes[0] @property @@ -62,6 +67,6 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> ConstrainTimeRangeNode: # noqa: D102 assert len(new_parent_nodes) == 1 return ConstrainTimeRangeNode( - parent_node=new_parent_nodes[0], + parent_nodes=tuple(new_parent_nodes), time_range_constraint=self.time_range_constraint, ) diff --git a/metricflow/dataflow/nodes/filter_elements.py b/metricflow/dataflow/nodes/filter_elements.py index f250d1db52..e38ef4e0c2 100644 --- a/metricflow/dataflow/nodes/filter_elements.py +++ 
b/metricflow/dataflow/nodes/filter_elements.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Optional, Sequence, Tuple from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -11,60 +12,66 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class FilterElementsNode(DataflowPlanNode): - """Only passes the listed elements.""" + """Only passes the listed elements. - def __init__( # noqa: D107 - self, + Attributes: + include_specs: The specs for the elements that it should pass. + replace_description: Replace the default description with this. + distinct: If you only want the distinct values for the selected specs.. + """ + + include_specs: InstanceSpecSet + replace_description: Optional[str] = None + distinct: bool = False + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, include_specs: InstanceSpecSet, replace_description: Optional[str] = None, distinct: bool = False, - ) -> None: - self._include_specs = include_specs - self._replace_description = replace_description - self._parent_node = parent_node - self._distinct = distinct - super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) + ) -> FilterElementsNode: + return FilterElementsNode( + parent_nodes=(parent_node,), + include_specs=include_specs, + replace_description=replace_description, + distinct=distinct, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_PASS_FILTER_ELEMENTS_ID_PREFIX - @property - def include_specs(self) -> InstanceSpecSet: - """Returns the specs for the elements that it should pass.""" - return self._include_specs - - @property - def distinct(self) -> bool: - """True if you only want the distinct values for the selected specs.""" - return 
self._distinct - def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_filter_elements_node(self) @property def description(self) -> str: # noqa: D102 - if self._replace_description: - return self._replace_description + if self.replace_description: + return self.replace_description - return f"Pass Only Elements: {mf_pformat([x.qualified_name for x in self._include_specs.all_specs])}" + return f"Pass Only Elements: {mf_pformat([x.qualified_name for x in self.include_specs.all_specs])}" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 additional_properties: Tuple[DisplayedProperty, ...] = () - if not self._replace_description: + if not self.replace_description: additional_properties = tuple( - DisplayedProperty("include_spec", include_spec) for include_spec in self._include_specs.all_specs + DisplayedProperty("include_spec", include_spec) for include_spec in self.include_specs.all_specs ) + ( - DisplayedProperty("distinct", self._distinct), + DisplayedProperty("distinct", self.distinct), ) return tuple(super().displayed_properties) + additional_properties @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( @@ -76,8 +83,8 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> FilterElementsNode: # noqa: D102 assert len(new_parent_nodes) == 1 return FilterElementsNode( - parent_node=new_parent_nodes[0], + parent_nodes=tuple(new_parent_nodes), include_specs=self.include_specs, distinct=self.distinct, - replace_description=self._replace_description, + replace_description=self.replace_description, ) diff --git a/metricflow/dataflow/nodes/join_conversion_events.py 
b/metricflow/dataflow/nodes/join_conversion_events.py index 62474ec8a3..029fa0dd1f 100644 --- a/metricflow/dataflow/nodes/join_conversion_events.py +++ b/metricflow/dataflow/nodes/join_conversion_events.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, Sequence +from dataclasses import dataclass +from typing import Optional, Sequence, Tuple from dbt_semantic_interfaces.protocols import MetricTimeWindow from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -17,11 +18,35 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class JoinConversionEventsNode(DataflowPlanNode): - """Builds a data set containing successful conversion events.""" - - def __init__( - self, + """Builds a data set containing successful conversion events. + + Attributes: + base_node: node containing dataset for computing base events. + base_time_dimension_spec: time dimension for the base events to compute against. + conversion_node: node containing dataset to join base node for computing conversion events. + conversion_measure_spec: expose this measure in the resulting dataset for aggregation. + conversion_time_dimension_spec: time dimension for the conversion events to compute against. + unique_identifier_keys: columns to uniquely identify each conversion event. + entity_spec: the specific entity in which the conversion is happening for. + window: time range bound for when a conversion is still considered valid (default: INF). + constant_properties: optional set of elements (either dimension/entity) to join the base + event to the conversion event. + """ + + base_node: DataflowPlanNode + base_time_dimension_spec: TimeDimensionSpec + conversion_node: DataflowPlanNode + conversion_measure_spec: MeasureSpec + conversion_time_dimension_spec: TimeDimensionSpec + unique_identifier_keys: Tuple[InstanceSpec, ...] 
+ entity_spec: EntitySpec + window: Optional[MetricTimeWindow] + constant_properties: Optional[Tuple[ConstantPropertySpec, ...]] + + @staticmethod + def create( # noqa: D102 base_node: DataflowPlanNode, base_time_dimension_spec: TimeDimensionSpec, conversion_node: DataflowPlanNode, @@ -31,31 +56,19 @@ def __init__( entity_spec: EntitySpec, window: Optional[MetricTimeWindow] = None, constant_properties: Optional[Sequence[ConstantPropertySpec]] = None, - ) -> None: - """Constructor. - - Args: - base_node: node containing dataset for computing base events. - base_time_dimension_spec: time dimension for the base events to compute against. - conversion_node: node containing dataset to join base node for computing conversion events. - conversion_measure_spec: expose this measure in the resulting dataset for aggregation. - conversion_time_dimension_spec: time dimension for the conversion events to compute against. - unique_identifier_keys: columns to uniquely identify each conversion event. - entity_spec: the specific entity in which the conversion is happening for. - window: time range bound for when a conversion is still considered valid (default: INF). - constant_properties: optional set of elements (either dimension/entity) to join the base - event to the conversion event. 
- """ - self._base_node = base_node - self._conversion_node = conversion_node - self._base_time_dimension_spec = base_time_dimension_spec - self._conversion_measure_spec = conversion_measure_spec - self._conversion_time_dimension_spec = conversion_time_dimension_spec - self._unique_identifier_keys = unique_identifier_keys - self._entity_spec = entity_spec - self._window = window - self._constant_properties = constant_properties - super().__init__(node_id=self.create_unique_id(), parent_nodes=(base_node, conversion_node)) + ) -> JoinConversionEventsNode: + return JoinConversionEventsNode( + parent_nodes=(base_node, conversion_node), + base_node=base_node, + base_time_dimension_spec=base_time_dimension_spec, + conversion_node=conversion_node, + conversion_measure_spec=conversion_measure_spec, + conversion_time_dimension_spec=conversion_time_dimension_spec, + unique_identifier_keys=tuple(unique_identifier_keys), + entity_spec=entity_spec, + window=window, + constant_properties=tuple(constant_properties) if constant_properties is not None else None, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -64,42 +77,6 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_join_conversion_events_node(self) - @property - def base_node(self) -> DataflowPlanNode: # noqa: D102 - return self._base_node - - @property - def conversion_node(self) -> DataflowPlanNode: # noqa: D102 - return self._conversion_node - - @property - def conversion_measure_spec(self) -> MeasureSpec: # noqa: D102 - return self._conversion_measure_spec - - @property - def base_time_dimension_spec(self) -> TimeDimensionSpec: # noqa: D102 - return self._base_time_dimension_spec - - @property - def conversion_time_dimension_spec(self) -> TimeDimensionSpec: # noqa: D102 - return self._conversion_time_dimension_spec - - @property - def unique_identifier_keys(self) -> Sequence[InstanceSpec]: 
# noqa: D102 - return self._unique_identifier_keys - - @property - def entity_spec(self) -> EntitySpec: # noqa: D102 - return self._entity_spec - - @property - def window(self) -> Optional[MetricTimeWindow]: # noqa: D102 - return self._window - - @property - def constant_properties(self) -> Optional[Sequence[ConstantPropertySpec]]: # noqa: D102 - return self._constant_properties - @property def description(self) -> str: # noqa: D102 return f"Find conversions for {self.entity_spec.qualified_name} within the range of {f'{self.window.count} {self.window.granularity.value}' if self.window else 'INF'}" @@ -136,6 +113,7 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinConversionEventsNode: # noqa: D102 assert len(new_parent_nodes) == 2 return JoinConversionEventsNode( + parent_nodes=tuple(new_parent_nodes), base_node=new_parent_nodes[0], base_time_dimension_spec=self.base_time_dimension_spec, conversion_node=new_parent_nodes[1], diff --git a/metricflow/dataflow/nodes/join_over_time.py b/metricflow/dataflow/nodes/join_over_time.py index bdef913556..8dd71c24bb 100644 --- a/metricflow/dataflow/nodes/join_over_time.py +++ b/metricflow/dataflow/nodes/join_over_time.py @@ -1,11 +1,12 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Optional, Sequence from dbt_semantic_interfaces.protocols import MetricTimeWindow from dbt_semantic_interfaces.type_enums import TimeGranularity from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix -from metricflow_semantics.dag.mf_dag import DisplayedProperty, NodeId +from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.filters.time_constraint import TimeRangeConstraint from metricflow_semantics.specs.spec_classes import TimeDimensionSpec from metricflow_semantics.visitor import VisitorOutputT @@ -13,43 +14,46 @@ from 
metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class JoinOverTimeRangeNode(DataflowPlanNode): - """A node that allows for cumulative metric computation by doing a self join across a cumulative date range.""" - - def __init__( - self, + """A node that allows for cumulative metric computation by doing a self join across a cumulative date range. + + Attributes: + queried_agg_time_dimension_specs: Time dimension specs that will be selected from the time spine table. + window: Time window to join over. + grain_to_date: Indicates time range should start from the beginning of this time granularity (e.g., month to day). + time_range_constraint: Time range to aggregate over. + """ + + queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec] + window: Optional[MetricTimeWindow] + grain_to_date: Optional[TimeGranularity] + time_range_constraint: Optional[TimeRangeConstraint] + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec], - window: Optional[MetricTimeWindow], - grain_to_date: Optional[TimeGranularity], - node_id: Optional[NodeId] = None, + window: Optional[MetricTimeWindow] = None, + grain_to_date: Optional[TimeGranularity] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, - ) -> None: - """Constructor. 
- - Args: - parent_node: node with standard output - window: time window to join over - grain_to_date: indicates time range should start from the beginning of this time granularity - (eg month to day) - node_id: Override the node ID with this value - time_range_constraint: time range to aggregate over - queried_agg_time_dimension_specs: time dimension specs that will be selected from time spine table - """ + ) -> JoinOverTimeRangeNode: if window and grain_to_date: raise RuntimeError( f"This node cannot be initialized with both window and grain_to_date set. This configuration should " f"have been prevented by model validation. window: {window}. grain_to_date: {grain_to_date}." ) - self._parent_node = parent_node - self._grain_to_date = grain_to_date - self._window = window - self.time_range_constraint = time_range_constraint - self.queried_agg_time_dimension_specs = queried_agg_time_dimension_specs - - # Doing a list comprehension throws a type error, so doing it this way. - parent_nodes: Sequence[DataflowPlanNode] = (self._parent_node,) - super().__init__(node_id=node_id or self.create_unique_id(), parent_nodes=parent_nodes) + return JoinOverTimeRangeNode( + parent_nodes=(parent_node,), + queried_agg_time_dimension_specs=queried_agg_time_dimension_specs, + window=window, + grain_to_date=grain_to_date, + time_range_constraint=time_range_constraint, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -58,21 +62,13 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_join_over_time_range_node(self) - @property - def grain_to_date(self) -> Optional[TimeGranularity]: # noqa: D102 - return self._grain_to_date - @property def description(self) -> str: # noqa: D102 return """Join Self Over Time Range""" @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node - - @property - def window(self) -> 
Optional[MetricTimeWindow]: # noqa: D102 - return self._window + return self.parent_nodes[0] @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -99,7 +95,7 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinOverTimeRangeNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return JoinOverTimeRangeNode( + return JoinOverTimeRangeNode.create( parent_node=new_parent_nodes[0], window=self.window, grain_to_date=self.grain_to_date, diff --git a/metricflow/dataflow/nodes/join_to_base.py b/metricflow/dataflow/nodes/join_to_base.py index 4894bcc459..eb7cd332d1 100644 --- a/metricflow/dataflow/nodes/join_to_base.py +++ b/metricflow/dataflow/nodes/join_to_base.py @@ -1,10 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix -from metricflow_semantics.dag.mf_dag import DisplayedProperty, NodeId +from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.specs.spec_classes import LinklessEntitySpec, TimeDimensionSpec from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.visitor import VisitorOutputT @@ -42,30 +42,29 @@ def __post_init__(self) -> None: # noqa: D105 raise RuntimeError("`join_on_entity` is required unless using CROSS JOIN.") +@dataclass(frozen=True) class JoinOnEntitiesNode(DataflowPlanNode): - """A node that joins data from other nodes via the entities in the inputs.""" + """A node that joins data from other nodes via the entities in the inputs. + + Attributes: + left_node: Node with standard output. + join_targets: Other sources that should be joined to this node. + """ + + left_node: DataflowPlanNode + join_targets: Tuple[JoinDescription, ...] 
- def __init__( - self, + @staticmethod + def create( # noqa: D102 left_node: DataflowPlanNode, join_targets: Sequence[JoinDescription], - node_id: Optional[NodeId] = None, - ) -> None: - """Constructor. - - Args: - left_node: node with standard output - join_targets: other sources that should be joined to this node. - node_id: Override the node ID with this value - """ - self._left_node = left_node - self._join_targets = tuple(join_targets) - - # Doing a list comprehension throws a type error, so doing it this way. - parent_nodes: List[DataflowPlanNode] = [self._left_node] - for join_target in self._join_targets: - parent_nodes.append(join_target.join_node) - super().__init__(node_id=node_id or self.create_unique_id(), parent_nodes=parent_nodes) + ) -> JoinOnEntitiesNode: + parent_nodes = [left_node] + [join_target.join_node for join_target in join_targets] + return JoinOnEntitiesNode( + parent_nodes=tuple(parent_nodes), + left_node=left_node, + join_targets=tuple(join_targets), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -78,19 +77,11 @@ def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOut def description(self) -> str: # noqa: D102 return """Join Standard Outputs""" - @property - def left_node(self) -> DataflowPlanNode: # noqa: D102 - return self._left_node - - @property - def join_targets(self) -> Sequence[JoinDescription]: # noqa: D102 - return self._join_targets - @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return tuple(super().displayed_properties) + tuple( DisplayedProperty(f"join{i}_for_node_id_{join_description.join_node.node_id}", join_description) - for i, join_description in enumerate(self._join_targets) + for i, join_description in enumerate(self.join_targets) ) def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 @@ -113,9 +104,9 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Join assert 
len(new_parent_nodes) > 1 new_left_node = new_parent_nodes[0] new_join_nodes = new_parent_nodes[1:] - assert len(new_join_nodes) == len(self._join_targets) + assert len(new_join_nodes) == len(self.join_targets) - return JoinOnEntitiesNode( + return JoinOnEntitiesNode.create( left_node=new_left_node, join_targets=[ JoinDescription( @@ -126,6 +117,6 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Join validity_window=old_join_target.validity_window, join_type=old_join_target.join_type, ) - for i, old_join_target in enumerate(self._join_targets) + for i, old_join_target in enumerate(self.join_targets) ], ) diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py index 0d8a9a298a..86f3104ed4 100644 --- a/metricflow/dataflow/nodes/join_to_time_spine.py +++ b/metricflow/dataflow/nodes/join_to_time_spine.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC +from dataclasses import dataclass from typing import Optional, Sequence from dbt_semantic_interfaces.protocols import MetricTimeWindow @@ -15,11 +16,39 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class JoinToTimeSpineNode(DataflowPlanNode, ABC): - """Join parent dataset to time spine dataset.""" + """Join parent dataset to time spine dataset. + + Attributes: + requested_agg_time_dimension_specs: Time dimensions requested in the query. + use_custom_agg_time_dimension: Indicates if agg_time_dimension should be used in join. If false, uses metric_time. + join_type: Join type to use when joining to time spine. + time_range_constraint: Time range to constrain the time spine to. + offset_window: Time window to offset the parent dataset by when joining to time spine. + offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine. 
+ """ + + requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec] + use_custom_agg_time_dimension: bool + join_type: SqlJoinType + time_range_constraint: Optional[TimeRangeConstraint] + offset_window: Optional[MetricTimeWindow] + offset_to_grain: Optional[TimeGranularity] + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 - def __init__( - self, + assert not ( + self.offset_window and self.offset_to_grain + ), "Can't set both offset_window and offset_to_grain when joining to time spine. Choose one or the other." + assert ( + len(self.requested_agg_time_dimension_specs) > 0 + ), "Must have at least one value in requested_agg_time_dimension_specs for JoinToTimeSpineNode." + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec], use_custom_agg_time_dimension: bool, @@ -27,70 +56,21 @@ def __init__( time_range_constraint: Optional[TimeRangeConstraint] = None, offset_window: Optional[MetricTimeWindow] = None, offset_to_grain: Optional[TimeGranularity] = None, - ) -> None: - """Constructor. - - Args: - parent_node: Node that returns desired dataset to join to time spine. - requested_agg_time_dimension_specs: Time dimensions requested in query. - use_custom_agg_time_dimension: Indicates if agg_time_dimension should be used in join. If false, uses metric_time. - time_range_constraint: Time range to constrain the time spine to. - offset_window: Time window to offset the parent dataset by when joining to time spine. - offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine. - - Passing both offset_window and offset_to_grain not allowed. - """ - assert not ( - offset_window and offset_to_grain - ), "Can't set both offset_window and offset_to_grain when joining to time spine. Choose one or the other." 
- assert ( - len(requested_agg_time_dimension_specs) > 0 - ), "Must have at least one value in requested_agg_time_dimension_specs for JoinToTimeSpineNode." - - self._parent_node = parent_node - self._requested_agg_time_dimension_specs = tuple(requested_agg_time_dimension_specs) - self._use_custom_agg_time_dimension = use_custom_agg_time_dimension - self._offset_window = offset_window - self._offset_to_grain = offset_to_grain - self._time_range_constraint = time_range_constraint - self._join_type = join_type - - super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,)) + ) -> JoinToTimeSpineNode: + return JoinToTimeSpineNode( + parent_nodes=(parent_node,), + requested_agg_time_dimension_specs=tuple(requested_agg_time_dimension_specs), + use_custom_agg_time_dimension=use_custom_agg_time_dimension, + join_type=join_type, + time_range_constraint=time_range_constraint, + offset_window=offset_window, + offset_to_grain=offset_to_grain, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX - @property - def requested_agg_time_dimension_specs(self) -> Sequence[TimeDimensionSpec]: - """Time dimension specs to use when creating time spine table.""" - return self._requested_agg_time_dimension_specs - - @property - def use_custom_agg_time_dimension(self) -> bool: - """Whether or not metric_time was included in the query.""" - return self._use_custom_agg_time_dimension - - @property - def time_range_constraint(self) -> Optional[TimeRangeConstraint]: - """Time range constraint to apply when querying time spine table.""" - return self._time_range_constraint - - @property - def offset_window(self) -> Optional[MetricTimeWindow]: - """Time range constraint to apply when querying time spine table.""" - return self._offset_window - - @property - def offset_to_grain(self) -> Optional[TimeGranularity]: - """Time range constraint to apply when querying time spine table.""" - return 
self._offset_to_grain - - @property - def join_type(self) -> SqlJoinType: - """Join type to use when joining to time spine.""" - return self._join_type - def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_join_to_time_spine_node(self) @@ -101,17 +81,17 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return tuple(super().displayed_properties) + ( - DisplayedProperty("requested_agg_time_dimension_specs", self._requested_agg_time_dimension_specs), - DisplayedProperty("use_custom_agg_time_dimension", self._use_custom_agg_time_dimension), - DisplayedProperty("time_range_constraint", self._time_range_constraint), - DisplayedProperty("offset_window", self._offset_window), - DisplayedProperty("offset_to_grain", self._offset_to_grain), - DisplayedProperty("join_type", self._join_type), + DisplayedProperty("requested_agg_time_dimension_specs", self.requested_agg_time_dimension_specs), + DisplayedProperty("use_custom_agg_time_dimension", self.use_custom_agg_time_dimension), + DisplayedProperty("time_range_constraint", self.time_range_constraint), + DisplayedProperty("offset_window", self.offset_window), + DisplayedProperty("offset_to_grain", self.offset_to_grain), + DisplayedProperty("join_type", self.join_type), ) @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( @@ -126,7 +106,7 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinToTimeSpineNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return JoinToTimeSpineNode( + return JoinToTimeSpineNode.create( parent_node=new_parent_nodes[0], 
requested_agg_time_dimension_specs=self.requested_agg_time_dimension_specs, use_custom_agg_time_dimension=self.use_custom_agg_time_dimension, diff --git a/metricflow/dataflow/nodes/metric_time_transform.py b/metricflow/dataflow/nodes/metric_time_transform.py index 1dd5389f16..47e5df2ffd 100644 --- a/metricflow/dataflow/nodes/metric_time_transform.py +++ b/metricflow/dataflow/nodes/metric_time_transform.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from dbt_semantic_interfaces.references import TimeDimensionReference @@ -10,6 +11,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class MetricTimeDimensionTransformNode(DataflowPlanNode): """A node transforms the input data set so that it contains the metric time dimension and relevant measures. @@ -19,32 +21,41 @@ class MetricTimeDimensionTransformNode(DataflowPlanNode): Output: a data set similar to the input data set, but includes the configured aggregation time dimension as the metric time dimension and only contains measures that are defined to use it. + + Attributes: + aggregation_time_dimension_reference: The time dimension that measures in the input should be aggregated to. 
""" - def __init__( # noqa: D107 - self, + aggregation_time_dimension_reference: TimeDimensionReference + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, aggregation_time_dimension_reference: TimeDimensionReference, - ) -> None: - self._aggregation_time_dimension_reference = aggregation_time_dimension_reference - self._parent_node = parent_node - super().__init__(node_id=self.create_unique_id(), parent_nodes=[parent_node]) + ) -> MetricTimeDimensionTransformNode: + return MetricTimeDimensionTransformNode( + parent_nodes=(parent_node,), + aggregation_time_dimension_reference=aggregation_time_dimension_reference, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_SET_MEASURE_AGGREGATION_TIME - @property - def aggregation_time_dimension_reference(self) -> TimeDimensionReference: - """The time dimension that measures in the input should be aggregated to.""" - return self._aggregation_time_dimension_reference - def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_metric_time_dimension_transform_node(self) @property def description(self) -> str: # noqa: D102 - return f"Metric Time Dimension '{self.aggregation_time_dimension_reference.element_name}'" "" + return f"Metric Time Dimension '{self.aggregation_time_dimension_reference.element_name}'" + + @property + def parent_node(self) -> DataflowPlanNode: # noqa: D102 + return self.parent_nodes[0] @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -52,10 +63,6 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 DisplayedProperty("aggregation_time_dimension", self.aggregation_time_dimension_reference.element_name), ) - @property - def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node - 
def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( isinstance(other_node, self.__class__) @@ -66,7 +73,7 @@ def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] ) -> MetricTimeDimensionTransformNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return MetricTimeDimensionTransformNode( + return MetricTimeDimensionTransformNode.create( parent_node=new_parent_nodes[0], aggregation_time_dimension_reference=self.aggregation_time_dimension_reference, ) diff --git a/metricflow/dataflow/nodes/min_max.py b/metricflow/dataflow/nodes/min_max.py index 1f6268b76c..40fa160739 100644 --- a/metricflow/dataflow/nodes/min_max.py +++ b/metricflow/dataflow/nodes/min_max.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -8,12 +9,17 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class MinMaxNode(DataflowPlanNode): """Calculate the min and max of a single instance data set.""" - def __init__(self, parent_node: DataflowPlanNode) -> None: # noqa: D107 - self._parent_node = parent_node - super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create(parent_node: DataflowPlanNode) -> MinMaxNode: # noqa: D102 + return MinMaxNode(parent_nodes=(parent_node,)) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -28,11 +34,11 @@ def description(self) -> str: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return isinstance(other_node, self.__class__) def with_new_parents(self, 
new_parent_nodes: Sequence[DataflowPlanNode]) -> MinMaxNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return MinMaxNode(parent_node=new_parent_nodes[0]) + return MinMaxNode.create(parent_node=new_parent_nodes[0]) diff --git a/metricflow/dataflow/nodes/order_by_limit.py b/metricflow/dataflow/nodes/order_by_limit.py index 7319618954..45eef8cddf 100644 --- a/metricflow/dataflow/nodes/order_by_limit.py +++ b/metricflow/dataflow/nodes/order_by_limit.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, Sequence, Union +from dataclasses import dataclass +from typing import Optional, Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty @@ -13,61 +14,58 @@ ) +@dataclass(frozen=True) class OrderByLimitNode(DataflowPlanNode): - """A node that re-orders the input data with a limit.""" + """A node that re-orders the input data with a limit. - def __init__( - self, + Attributes: + order_by_specs: Describes how to order the incoming data. + limit: Number of rows to limit. + """ + + order_by_specs: Sequence[OrderBySpec] + limit: Optional[int] + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 order_by_specs: Sequence[OrderBySpec], - parent_node: Union[DataflowPlanNode, DataflowPlanNode], + parent_node: DataflowPlanNode, limit: Optional[int] = None, - ) -> None: - """Constructor. - - Args: - order_by_specs: describes how to order the incoming data. - limit: number of rows to limit. - parent_node: self-explanatory. 
- """ - self._order_by_specs = tuple(order_by_specs) - self._limit = limit - self._parent_node = parent_node - super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,)) + ) -> OrderByLimitNode: + return OrderByLimitNode( + parent_nodes=(parent_node,), + order_by_specs=tuple(order_by_specs), + limit=limit, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_ORDER_BY_LIMIT_ID_PREFIX - @property - def order_by_specs(self) -> Sequence[OrderBySpec]: - """The elements that this node should order the input data.""" - return self._order_by_specs - - @property - def limit(self) -> Optional[int]: - """The number of rows to limit by.""" - return self._limit - def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_order_by_limit_node(self) @property def description(self) -> str: # noqa: D102 - return f"Order By {[order_by_spec.instance_spec.qualified_name for order_by_spec in self._order_by_specs]}" + ( - f" Limit {self._limit}" if self.limit else "" + return f"Order By {[order_by_spec.instance_spec.qualified_name for order_by_spec in self.order_by_specs]}" + ( + f" Limit {self.limit}" if self.limit else "" ) @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return ( tuple(super().displayed_properties) - + tuple(DisplayedProperty("order_by_spec", order_by_spec) for order_by_spec in self._order_by_specs) + + tuple(DisplayedProperty("order_by_spec", order_by_spec) for order_by_spec in self.order_by_specs) + (DisplayedProperty("limit", str(self.limit)),) ) @property - def parent_node(self) -> Union[DataflowPlanNode, DataflowPlanNode]: # noqa: D102 - return self._parent_node + def parent_node(self) -> DataflowPlanNode: # noqa: D102 + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( @@ -79,7 +77,7 @@ def functionally_identical(self, 
other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> OrderByLimitNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return OrderByLimitNode( + return OrderByLimitNode.create( parent_node=new_parent_nodes[0], order_by_specs=self.order_by_specs, limit=self.limit, diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index 0010a77661..de1da2f604 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -1,6 +1,7 @@ from __future__ import annotations import textwrap +from dataclasses import dataclass from typing import Optional, Sequence import jinja2 @@ -14,17 +15,28 @@ from metricflow.dataset.sql_dataset import SqlDataSet +@dataclass(frozen=True) class ReadSqlSourceNode(DataflowPlanNode): - """A source node where data from an SQL table or SQL query is read and output.""" + """A source node where data from an SQL table or SQL query is read and output. - def __init__(self, data_set: SqlDataSet) -> None: - """Constructor. + Attributes: + data_set: Dataset describing the SQL table / SQL query. 
+ """ - Args: - data_set: dataset describing the SQL table / SQL query - """ - self._dataset = data_set - super().__init__(node_id=self.create_unique_id(), parent_nodes=()) + data_set: SqlDataSet + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 0 + + @staticmethod + def create( # noqa: D102 + data_set: SqlDataSet, + ) -> ReadSqlSourceNode: + return ReadSqlSourceNode( + parent_nodes=(), + data_set=data_set, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -39,11 +51,6 @@ def _input_semantic_model(self) -> Optional[SemanticModelReference]: """Return the semantic model serving as direct input for this node, if one exists.""" return self.data_set.semantic_model_reference - @property - def data_set(self) -> SqlDataSet: - """Return the data set that this source represents and is passed to the child nodes.""" - return self._dataset - def __str__(self) -> str: # noqa: D105 return jinja2.Template( textwrap.dedent( @@ -66,4 +73,4 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> ReadSqlSourceNode: # noqa: D102 assert len(new_parent_nodes) == 0 - return ReadSqlSourceNode(data_set=self.data_set) + return ReadSqlSourceNode.create(data_set=self.data_set) diff --git a/metricflow/dataflow/nodes/semi_additive_join.py b/metricflow/dataflow/nodes/semi_additive_join.py index f68d2e3765..b48f10602e 100644 --- a/metricflow/dataflow/nodes/semi_additive_join.py +++ b/metricflow/dataflow/nodes/semi_additive_join.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Optional, Sequence from dbt_semantic_interfaces.type_enums import AggregationType @@ -11,6 +12,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class SemiAdditiveJoinNode(DataflowPlanNode): """A node that performs a row 
filter by aggregating a given non-additive dimension. @@ -24,9 +26,9 @@ class SemiAdditiveJoinNode(DataflowPlanNode): Data transformation example, | date | account_balance | user | | date | account_balance | user | - |:-----------|-----------------|-----:| entity_specs: |:-----------|-----------------|-----:| + |:-----------|-----------------|-----:| entity_specs: |:-----------|-----------------|-----:| | 2019-12-31 | 1000 | u1| - user | 2020-01-03 | 2000 | u1| - | 2020-01-03 | 2000 | u1| -> time_dimension_spec: -> | 2020-01-12 1500 | u2| + | 2020-01-03 | 2000 | u1| -> time_dimension_spec: -> | 2020-01-12 | 1500 | u2| | 2020-01-09 | 3000 | u2| - date | 2020-01-12 | 1000 | u3| | 2020-01-12 | 1500 | u2| agg_by_function: | 2020-01-12 | 1000 | u3| - MAX @@ -37,7 +39,7 @@ class SemiAdditiveJoinNode(DataflowPlanNode): Data transformation example, | date | account_balance | user | | date | account_balance | - |:-----------|-----------------|-----:| entity_specs: |:-----------|----------------:| + |:-----------|-----------------|-----:| entity_specs: |:-----------|----------------:| | 2019-12-31 | 1000 | u1| | 2020-01-12 | 2500 | | 2020-01-03 | 2000 | u1| -> time_dimension_spec: -> | 2020-01-09 | 3000 | u2| - date @@ -50,7 +52,7 @@ class SemiAdditiveJoinNode(DataflowPlanNode): Data transformation example, | date | account_balance | user | | date | account_balance | - |:-----------|-----------------|-----:| entity_specs: |:-----------|----------------:| + |:-----------|-----------------|-----:| entity_specs: |:-----------|----------------:| | 2019-12-31 | 1500 | u1| time_dimension_spec: | 2019-12-31 | 1500 | | 2020-01-03 | 2000 | u1| -> - date -> | 2020-01-07 | 3000 | | 2020-01-09 | 3000 | u2| agg_by_function: | 2020-01-14 | 3250 | @@ -58,34 +60,38 @@ class SemiAdditiveJoinNode(DataflowPlanNode): | 2020-01-14 | 1250 | u3| queried_time_dimension_spec: | 2020-01-14 | 2000 | u2| - date__week | 2020-01-15 | 4000 | u1| + + Attributes: + entity_specs: The entities to group the 
join by. + time_dimension_spec: The time dimension used for row filtering via an aggregation. + agg_by_function: The aggregation function used on the time dimension. + queried_time_dimension_spec: The group by provided in the query used to build the windows we want to filter on. """ - def __init__( - self, + entity_specs: Sequence[LinklessEntitySpec] + time_dimension_spec: TimeDimensionSpec + agg_by_function: AggregationType + queried_time_dimension_spec: Optional[TimeDimensionSpec] + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, entity_specs: Sequence[LinklessEntitySpec], time_dimension_spec: TimeDimensionSpec, agg_by_function: AggregationType, queried_time_dimension_spec: Optional[TimeDimensionSpec] = None, - ) -> None: - """Constructor. - - Args: - parent_node: node with standard output - entity_specs: the entities to group the join by - time_dimension_spec: the time dimension used for row filtering via an aggregation - agg_by_function: the aggregation function used on the time dimension - queried_time_dimension_spec: The group by provided in the query used to build the windows we want to filter on. - """ - self._parent_node = parent_node - self._entity_specs = tuple(entity_specs) - self._time_dimension_spec = time_dimension_spec - self._agg_by_function = agg_by_function - self._queried_time_dimension_spec = queried_time_dimension_spec - - # Doing a list comprehension throws a type error, so doing it this way. 
- parent_nodes: Sequence[DataflowPlanNode] = (self._parent_node,) - super().__init__(node_id=self.create_unique_id(), parent_nodes=parent_nodes) + ) -> SemiAdditiveJoinNode: + return SemiAdditiveJoinNode( + parent_nodes=(parent_node,), + entity_specs=tuple(entity_specs), + time_dimension_spec=time_dimension_spec, + agg_by_function=agg_by_function, + queried_time_dimension_spec=queried_time_dimension_spec, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -100,23 +106,7 @@ def description(self) -> str: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - return self._parent_node - - @property - def entity_specs(self) -> Sequence[LinklessEntitySpec]: # noqa: D102 - return self._entity_specs - - @property - def time_dimension_spec(self) -> TimeDimensionSpec: # noqa: D102 - return self._time_dimension_spec - - @property - def agg_by_function(self) -> AggregationType: # noqa: D102 - return self._agg_by_function - - @property - def queried_time_dimension_spec(self) -> Optional[TimeDimensionSpec]: # noqa: D102 - return self._queried_time_dimension_spec + return self.parent_nodes[0] @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -127,8 +117,7 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: return False return ( - isinstance(other_node, self.__class__) - and other_node.entity_specs == self.entity_specs + other_node.entity_specs == self.entity_specs and other_node.time_dimension_spec == self.time_dimension_spec and other_node.agg_by_function == self.agg_by_function and other_node.queried_time_dimension_spec == self.queried_time_dimension_spec @@ -137,7 +126,7 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> SemiAdditiveJoinNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return SemiAdditiveJoinNode( + return SemiAdditiveJoinNode.create( 
parent_node=new_parent_nodes[0], entity_specs=self.entity_specs, time_dimension_spec=self.time_dimension_spec, diff --git a/metricflow/dataflow/nodes/where_filter.py b/metricflow/dataflow/nodes/where_filter.py index 04240a9029..fd3f1a527a 100644 --- a/metricflow/dataflow/nodes/where_filter.py +++ b/metricflow/dataflow/nodes/where_filter.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -10,31 +11,33 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +@dataclass(frozen=True) class WhereConstraintNode(DataflowPlanNode): - """Remove rows using a WHERE clause.""" + """Remove rows using a WHERE clause. - def __init__( - self, + Attributes: + where_specs: Specifications for the WHERE clause to filter rows. + always_apply: Indicator if the WHERE clause should always be applied. + """ + + where_specs: Sequence[WhereFilterSpec] + always_apply: bool + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, where_specs: Sequence[WhereFilterSpec], always_apply: bool = False, - ) -> None: - """Initializer. - - WhereConstraintNodes must always have exactly one parent, since they always wrap a single subquery input. - - The always_apply parameter serves as an indicator for a WhereConstraintNode that is added to a plan in order - to clean up null outputs from a pre-join filter. For example, when doing time spine joins to fill null values - for metric outputs sometimes that join will result in rows with null values for various dimension attributes. - By re-applying the filter expression after the join step we will discard those unexpected output rows created - by the join (rather than the underlying inputs). 
In this case, we must ensure that the filters defined in this - node are always applied at the moment this node is processed, regardless of whether or not they've been pushed - down through the DAG. - """ - self._where_specs = where_specs - self.parent_node = parent_node - self.always_apply = always_apply - super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) + ) -> WhereConstraintNode: + return WhereConstraintNode( + parent_nodes=(parent_node,), + where_specs=where_specs, + always_apply=always_apply, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -43,7 +46,7 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 @property def where(self) -> WhereFilterSpec: """Returns the specs for the elements that it should pass.""" - return WhereFilterSpec.merge_iterable(self._where_specs) + return WhereFilterSpec.merge_iterable(self.where_specs) @property def input_where_specs(self) -> Sequence[WhereFilterSpec]: @@ -53,7 +56,7 @@ def input_where_specs(self) -> Sequence[WhereFilterSpec]: for pushdown operations on the filter spec level. We merge them for things like rendering and node comparisons, but in some cases we may be able to push down a subset of the input specs for efficiency reasons. """ - return self._where_specs + return self.where_specs def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_where_constraint_node(self) @@ -64,6 +67,10 @@ def description(self) -> str: # noqa: D102 # e.g. 
"Constrain Output with WHERE listing__country = :1" return "Constrain Output with WHERE" + @property + def parent_node(self) -> DataflowPlanNode: # noqa: D102 + return self.parent_nodes[0] + @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 properties = tuple(super().displayed_properties) + (DisplayedProperty("where_condition", self.where),) @@ -80,6 +87,8 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> WhereConstraintNode: # noqa: D102 assert len(new_parent_nodes) == 1 - return WhereConstraintNode( - parent_node=new_parent_nodes[0], where_specs=self.input_where_specs, always_apply=self.always_apply + return WhereConstraintNode.create( + parent_node=new_parent_nodes[0], + where_specs=self.input_where_specs, + always_apply=self.always_apply, ) diff --git a/metricflow/dataflow/nodes/window_reaggregation_node.py b/metricflow/dataflow/nodes/window_reaggregation_node.py index c713fa904c..325bd666b8 100644 --- a/metricflow/dataflow/nodes/window_reaggregation_node.py +++ b/metricflow/dataflow/nodes/window_reaggregation_node.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence, Set from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -14,32 +15,45 @@ from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode +@dataclass(frozen=True) class WindowReaggregationNode(DataflowPlanNode): """A node that re-aggregates metrics using window functions. Currently used for calculating cumulative metrics at various granularities. + + Attributes: + metric_spec: Specification of the metric to be re-aggregated. + order_by_spec: Specification of the time dimension to order by. + partition_by_specs: Specifications of the instances to partition by. 
""" - def __init__( # noqa: D107 - self, + metric_spec: MetricSpec + order_by_spec: TimeDimensionSpec + partition_by_specs: Sequence[InstanceSpec] + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + if self.order_by_spec in self.partition_by_specs: + raise ValueError( + "Order by spec found in partition by specs for WindowAggregationNode. This indicates internal misconfiguration" + f" because reaggregation should not be needed in this circumstance. Order by spec: {self.order_by_spec}; " + f"Partition by specs:{self.partition_by_specs}" + ) + + @staticmethod + def create( # noqa: D102 parent_node: ComputeMetricsNode, metric_spec: MetricSpec, order_by_spec: TimeDimensionSpec, partition_by_specs: Sequence[InstanceSpec], - ) -> None: - if order_by_spec in partition_by_specs: - raise ValueError( - "Order by spec found in parition by specs for WindowAggregationNode. This indicates internal misconfiguration" - f" because reaggregation should not be needed in this circumstance. 
Order by spec: {order_by_spec}; " - f"Partition by specs:{partition_by_specs}" - ) - - self.parent_node = parent_node - self.metric_spec = metric_spec - self.order_by_spec = order_by_spec - self.partition_by_specs = partition_by_specs - - super().__init__(node_id=self.create_unique_id(), parent_nodes=(self.parent_node,)) + ) -> WindowReaggregationNode: + return WindowReaggregationNode( + parent_nodes=(parent_node,), + metric_spec=metric_spec, + order_by_spec=order_by_spec, + partition_by_specs=partition_by_specs, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -52,6 +66,10 @@ def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOut def description(self) -> str: # noqa: D102 return """Re-aggregate Metrics via Window Functions""" + @property + def parent_node(self) -> DataflowPlanNode: # noqa: D102 + return self.parent_nodes[0] + @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return tuple(super().displayed_properties) + ( @@ -63,7 +81,7 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( isinstance(other_node, self.__class__) - and other_node.parent_node == self.parent_node + and other_node.parent_nodes == self.parent_nodes and other_node.metric_spec == self.metric_spec and other_node.order_by_spec == self.order_by_spec and other_node.partition_by_specs == self.partition_by_specs @@ -75,7 +93,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Wind assert isinstance( new_parent_node, ComputeMetricsNode ), "WindowReaggregationNode can only have ComputeMetricsNode as parent node." 
- return WindowReaggregationNode( + return WindowReaggregationNode.create( parent_node=new_parent_node, metric_spec=self.metric_spec, order_by_spec=self.order_by_spec, diff --git a/metricflow/dataflow/nodes/write_to_data_table.py b/metricflow/dataflow/nodes/write_to_data_table.py index 3bd9d3e664..39f6eb0fb0 100644 --- a/metricflow/dataflow/nodes/write_to_data_table.py +++ b/metricflow/dataflow/nodes/write_to_data_table.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -11,12 +12,21 @@ ) +@dataclass(frozen=True) class WriteToResultDataTableNode(DataflowPlanNode): """A node where incoming data gets written to a data_table.""" - def __init__(self, parent_node: DataflowPlanNode) -> None: # noqa: D107 - self._parent_node = parent_node - super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 + parent_node: DataflowPlanNode, + ) -> WriteToResultDataTableNode: + return WriteToResultDataTableNode( + parent_nodes=(parent_node,), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -31,8 +41,7 @@ def description(self) -> str: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - assert len(self.parent_nodes) == 1 - return self._parent_node + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return isinstance(other_node, self.__class__) @@ -41,4 +50,4 @@ def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] ) -> WriteToResultDataTableNode: assert len(new_parent_nodes) == 1 - return WriteToResultDataTableNode(parent_node=new_parent_nodes[0]) + return WriteToResultDataTableNode.create(parent_node=new_parent_nodes[0]) diff --git 
a/metricflow/dataflow/nodes/write_to_table.py b/metricflow/dataflow/nodes/write_to_table.py index b513754fee..d27d974a20 100644 --- a/metricflow/dataflow/nodes/write_to_table.py +++ b/metricflow/dataflow/nodes/write_to_table.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix @@ -12,23 +13,29 @@ from metricflow.sql.sql_table import SqlTable +@dataclass(frozen=True) class WriteToResultTableNode(DataflowPlanNode): - """A node where incoming data gets written to a table.""" + """A node where incoming data gets written to a table. - def __init__( - self, + Attributes: + output_sql_table: The table where the computed metrics should be written to. + """ + + output_sql_table: SqlTable + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 parent_node: DataflowPlanNode, output_sql_table: SqlTable, - ) -> None: - """Constructor. - - Args: - parent_node: node that outputs the computed metrics. - output_sql_table: the table where the computed metrics should be written to. 
- """ - self._parent_node = parent_node - self._output_sql_table = output_sql_table - super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) + ) -> WriteToResultTableNode: + return WriteToResultTableNode( + parent_nodes=(parent_node,), + output_sql_table=output_sql_table, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -43,18 +50,13 @@ def description(self) -> str: # noqa: D102 @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 - assert len(self.parent_nodes) == 1 - return self._parent_node - - @property - def output_sql_table(self) -> SqlTable: # noqa: D102 - return self._output_sql_table + return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return isinstance(other_node, self.__class__) and other_node.output_sql_table == self.output_sql_table def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> WriteToResultTableNode: # noqa: D102 - return WriteToResultTableNode( + return WriteToResultTableNode.create( parent_node=new_parent_nodes[0], output_sql_table=self.output_sql_table, ) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index aed67dea7e..41ad2c2acb 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -312,7 +312,7 @@ def _push_down_where_filters( optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state) if len(filters_to_apply) > 0: return OptimizeBranchResult( - optimized_branch=WhereConstraintNode( + optimized_branch=WhereConstraintNode.create( parent_node=optimized_node.optimized_branch, where_specs=filters_to_apply ) ) @@ -397,7 +397,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran ) elif len(filter_specs_to_apply) > 0: optimized_node = OptimizeBranchResult( - 
optimized_branch=WhereConstraintNode( + optimized_branch=WhereConstraintNode.create( parent_node=optimized_parent.optimized_branch, where_specs=filter_specs_to_apply ) ) diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 1124f00498..732b41452c 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -259,7 +259,7 @@ def visit_aggregate_measures_node( # noqa: D102 ) return ComputeMetricsBranchCombinerResult() - combined_node = AggregateMeasuresNode( + combined_node = AggregateMeasuresNode.create( parent_node=combined_parent_node, metric_input_measure_specs=combined_metric_input_measure_specs, ) @@ -305,7 +305,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetrics if metric_spec not in unique_metric_specs: unique_metric_specs.append(metric_spec) - combined_node = ComputeMetricsNode( + combined_node = ComputeMetricsNode.create( parent_node=combined_parent_node, metric_specs=unique_metric_specs, aggregated_to_elements=current_right_node.aggregated_to_elements, @@ -389,7 +389,7 @@ def visit_filter_elements_node(self, node: FilterElementsNode) -> ComputeMetrics # De-dupe so that we don't see the same spec twice in include specs. For example, this can happen with dimension # specs since any branch that is merged together needs to output the same set of dimensions. 
- combined_node = FilterElementsNode( + combined_node = FilterElementsNode.create( parent_node=combined_parent_node, include_specs=self._current_left_node.include_specs.merge(current_right_node.include_specs).dedupe(), ) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index e7f87aa992..5fa9fc4602 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -148,7 +148,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranch optimized_parent_result: OptimizeBranchResult = node.parent_node.accept(self) if optimized_parent_result.optimized_branch is not None: return OptimizeBranchResult( - optimized_branch=ComputeMetricsNode( + optimized_branch=ComputeMetricsNode.create( parent_node=optimized_parent_result.optimized_branch, metric_specs=node.metric_specs, for_group_by_source_node=node.for_group_by_source_node, @@ -264,7 +264,7 @@ def visit_combine_aggregated_outputs_node( # noqa: D102 return OptimizeBranchResult(optimized_branch=combined_parent_branches[0]) return OptimizeBranchResult( - optimized_branch=CombineAggregatedOutputsNode(parent_nodes=combined_parent_branches) + optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches) ) def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult: # noqa: D102 diff --git a/metricflow/dataset/convert_semantic_model.py b/metricflow/dataset/convert_semantic_model.py index 670e768827..2c76394129 100644 --- a/metricflow/dataset/convert_semantic_model.py +++ b/metricflow/dataset/convert_semantic_model.py @@ -169,15 +169,15 @@ def _make_element_sql_expr( "FALSE", "NULL", ): - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias, column_name=element_expr, ) ) - return 
SqlStringExpression(sql_expr=element_expr) + return SqlStringExpression.create(sql_expr=element_expr) - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias, column_name=element_name, @@ -368,7 +368,7 @@ def _build_time_dimension_instances_and_columns( select_columns.append( SqlSelectColumn( - expr=SqlExtractExpression(date_part=date_part, arg=dimension_select_expr), + expr=SqlExtractExpression.create(date_part=date_part, arg=dimension_select_expr), column_alias=time_dimension_instance.associated_column.column_name, ) ) @@ -379,7 +379,7 @@ def _build_column_for_time_granularity( self, time_granularity: TimeGranularity, expr: SqlExpressionNode, column_alias: str ) -> SqlSelectColumn: return SqlSelectColumn( - expr=SqlDateTruncExpression(time_granularity=time_granularity, arg=expr), column_alias=column_alias + expr=SqlDateTruncExpression.create(time_granularity=time_granularity, arg=expr), column_alias=column_alias ) def _create_entity_instances( @@ -493,9 +493,11 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM all_select_columns.extend(select_columns) # Generate the "from" clause depending on whether it's an SQL query or an SQL table. 
- from_source = SqlTableFromClauseNode(sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name)) + from_source = SqlTableFromClauseNode.create( + sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name) + ) - select_statement_node = SqlSelectStatementNode( + select_statement_node = SqlSelectStatementNode.create( description=f"Read Elements From Semantic Model '{semantic_model.name}'", select_columns=tuple(all_select_columns), from_source=from_source, @@ -549,10 +551,10 @@ def build_time_spine_source_data_set(self, time_spine_source: TimeSpineSource) - return SqlDataSet( instance_set=InstanceSet(time_dimension_instances=tuple(time_dimension_instances)), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=TIME_SPINE_DATA_SET_DESCRIPTION, select_columns=tuple(select_columns), - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), + from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table), from_source_alias=from_source_alias, ), ) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index b553590465..4df6cc0f21 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -33,6 +33,7 @@ ExecutionPlan, SelectSqlQueryToDataTableTask, SelectSqlQueryToTableTask, + SqlQuery, ) from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter @@ -80,10 +81,9 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode render_sql_result = self._render_sql(convert_to_sql_plan_result) execution_plan = ExecutionPlan( leaf_tasks=( - SelectSqlQueryToDataTableTask( + SelectSqlQueryToDataTableTask.create( sql_client=self._sql_client, - sql_query=render_sql_result.sql, - bind_parameters=render_sql_result.bind_parameters, 
+ sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameters), ), ) ) @@ -99,10 +99,12 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv render_sql_result = self._render_sql(convert_to_sql_plan_result) execution_plan = ExecutionPlan( leaf_tasks=( - SelectSqlQueryToTableTask( + SelectSqlQueryToTableTask.create( sql_client=self._sql_client, - sql_query=render_sql_result.sql, - bind_parameters=render_sql_result.bind_parameters, + sql_query=SqlQuery( + sql_query=render_sql_result.sql, + bind_parameters=render_sql_result.bind_parameters, + ), output_table=node.output_sql_table, ), ), diff --git a/metricflow/execution/execution_plan.py b/metricflow/execution/execution_plan.py index 5f3e8e6f18..a117d73803 100644 --- a/metricflow/execution/execution_plan.py +++ b/metricflow/execution/execution_plan.py @@ -18,48 +18,35 @@ logger = logging.getLogger(__name__) -class ExecutionPlanTask(DagNode, Visitable, ABC): +@dataclass(frozen=True) +class ExecutionPlanTask(DagNode["ExecutionPlanTask"], Visitable, ABC): """A node (aka task) in the DAG representation of the execution plan. In the DAG, a node's parents represent the tasks that need to be run before the node can run. Using the term task for these nodes as it seems more intuitive. - """ - - def __init__(self, task_id: NodeId, parent_nodes: List[ExecutionPlanTask]) -> None: - """Constructor. - Args: - task_id: the ID for the node - parent_nodes: the nodes that should be executed before this one. - """ - self._parent_nodes = parent_nodes - super().__init__(node_id=task_id) + Attributes: + sql_query: If this runs a SQL query, return the associated SQL. 
+ """ - @property - def parent_nodes(self) -> Sequence[ExecutionPlanTask]: - """Return the nodes that should execute before this one.""" - return self._parent_nodes + sql_query: Optional[SqlQuery] @abstractmethod def execute(self) -> TaskExecutionResult: """Execute the actions of this node.""" + raise NotImplementedError @property def task_id(self) -> NodeId: """Alias for node ID since the nodes represent a task.""" return self.node_id - @property - @abstractmethod - def sql_query(self) -> Optional[SqlQuery]: - """If this runs a SQL query, return the associated SQL.""" - pass - @dataclass(frozen=True) class SqlQuery: """A SQL query that can be run along with bind parameters.""" + # This field will be renamed as it is confusing given the class name. sql_query: str bind_parameters: SqlBindParameters @@ -86,20 +73,30 @@ class TaskExecutionResult: df: Optional[MetricFlowDataTable] = None +@dataclass(frozen=True) class SelectSqlQueryToDataTableTask(ExecutionPlanTask): - """A task that runs a SELECT and puts that result into a data_table.""" + """A task that runs a SELECT and puts that result into a data_table. + + Attributes: + sql_client: The SQL client used to run the query. + sql_query: The SQL query to run. + parent_nodes: The parent tasks for this execution plan task. + """ - def __init__( # noqa: D107 - self, + sql_client: SqlClient + parent_nodes: Tuple[ExecutionPlanTask, ...] 
+ + @staticmethod + def create( # noqa: D102 sql_client: SqlClient, - sql_query: str, - bind_parameters: SqlBindParameters, - parent_nodes: Optional[List[ExecutionPlanTask]] = None, - ) -> None: - self._sql_client = sql_client - self._sql_query = sql_query - self._bind_parameters = bind_parameters - super().__init__(task_id=self.create_unique_id(), parent_nodes=parent_nodes or []) + sql_query: SqlQuery, + parent_nodes: Sequence[ExecutionPlanTask] = (), + ) -> SelectSqlQueryToDataTableTask: + return SelectSqlQueryToDataTableTask( + sql_client=sql_client, + sql_query=sql_query, + parent_nodes=tuple(parent_nodes), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -111,55 +108,61 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - return tuple(super().displayed_properties) + (DisplayedProperty(key="sql_query", value=self._sql_query),) - - @property - def bind_parameters(self) -> SqlBindParameters: # noqa: D102 - return self._bind_parameters + sql_query = self.sql_query + assert sql_query is not None, f"{self.sql_query=} should have been set during creation." + return tuple(super().displayed_properties) + (DisplayedProperty(key="sql_query", value=sql_query.sql_query),) def execute(self) -> TaskExecutionResult: # noqa: D102 start_time = time.time() + sql_query = self.sql_query + assert sql_query is not None, f"{self.sql_query=} should have been set during creation." 
- df = self._sql_client.query( - self._sql_query, - sql_bind_parameters=self.bind_parameters, + df = self.sql_client.query( + sql_query.sql_query, + sql_bind_parameters=sql_query.bind_parameters, ) end_time = time.time() return TaskExecutionResult( - start_time=start_time, end_time=end_time, sql=self._sql_query, bind_params=self.bind_parameters, df=df - ) - - @property - def sql_query(self) -> Optional[SqlQuery]: # noqa: D102 - return SqlQuery( - sql_query=self._sql_query, - bind_parameters=self._bind_parameters, + start_time=start_time, + end_time=end_time, + sql=sql_query.sql_query, + bind_params=sql_query.bind_parameters, + df=df, ) def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(sql_query='{self._sql_query}')" + return f"{self.__class__.__name__}(sql_query='{self.sql_query}')" +@dataclass(frozen=True) class SelectSqlQueryToTableTask(ExecutionPlanTask): """A task that runs a SELECT and puts that result into a table. The provided SQL query is the query that will be run, so it should be a CREATE... or similar. + + Attributes: + sql_client: The SQL client used to run the query. + sql_query: The SQL query to run. + output_table: The table where the results will be written. 
""" - def __init__( # noqa: D107 - self, + sql_client: SqlClient + output_table: SqlTable + + @staticmethod + def create( # noqa: D102 sql_client: SqlClient, - sql_query: str, - bind_parameters: SqlBindParameters, + sql_query: SqlQuery, output_table: SqlTable, - parent_nodes: Optional[List[ExecutionPlanTask]] = None, - ) -> None: - self._sql_client = sql_client - self._sql_query = sql_query - self._output_table = output_table - self._bind_parameters = bind_parameters - super().__init__(task_id=self.create_unique_id(), parent_nodes=parent_nodes or []) + parent_nodes: Sequence[ExecutionPlanTask] = (), + ) -> SelectSqlQueryToTableTask: + return SelectSqlQueryToTableTask( + sql_client=sql_client, + sql_query=sql_query, + output_table=output_table, + parent_nodes=tuple(parent_nodes), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -171,31 +174,31 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + sql_query = self.sql_query + assert sql_query is not None, f"{self.sql_query=} should have been set during creation." return tuple(super().displayed_properties) + ( - DisplayedProperty(key="sql_query", value=self._sql_query), - DisplayedProperty(key="output_table", value=self._output_table), - DisplayedProperty(key="bind_parameters", value=self._bind_parameters), + DisplayedProperty(key="sql_query", value=sql_query.sql_query), + DisplayedProperty(key="output_table", value=self.output_table), + DisplayedProperty(key="bind_parameters", value=sql_query.bind_parameters), ) def execute(self) -> TaskExecutionResult: # noqa: D102 + sql_query = self.sql_query + assert sql_query is not None, f"{self.sql_query=} should have been set during creation." 
start_time = time.time() - logger.info(f"Dropping table {self._output_table} in case it already exists") - self._sql_client.execute(f"DROP TABLE IF EXISTS {self._output_table.sql}") - logger.info(f"Creating table {self._output_table} using a query") - self._sql_client.execute( - self._sql_query, - sql_bind_parameters=self._bind_parameters, + logger.info(f"Dropping table {self.output_table} in case it already exists") + self.sql_client.execute(f"DROP TABLE IF EXISTS {self.output_table.sql}") + logger.info(f"Creating table {self.output_table} using a query") + self.sql_client.execute( + sql_query.sql_query, + sql_bind_parameters=sql_query.bind_parameters, ) end_time = time.time() - return TaskExecutionResult(start_time=start_time, end_time=end_time, sql=self._sql_query) - - @property - def sql_query(self) -> Optional[SqlQuery]: # noqa: D102 - return SqlQuery(sql_query=self._sql_query, bind_parameters=self._bind_parameters) + return TaskExecutionResult(start_time=start_time, end_time=end_time, sql=sql_query.sql_query) def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(sql_query='{self._sql_query}', output_table={self._output_table})" + return f"{self.__class__.__name__}(sql_query='{self.sql_query}', output_table={self.output_table})" class ExecutionPlan(MetricFlowDag[ExecutionPlanTask]): diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index c3ad1d542f..ad2d989989 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -150,17 +150,17 @@ def _make_time_range_comparison_expr( ) -> SqlExpressionNode: """Build an expression like "ds BETWEEN CAST('2020-01-01' AS TIMESTAMP) AND CAST('2020-01-02' AS TIMESTAMP).""" # TODO: Update when adding < day granularity support. 
- return SqlBetweenExpression( - column_arg=SqlColumnReferenceExpression( + return SqlBetweenExpression.create( + column_arg=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias, column_name=column_alias, ) ), - start_expr=SqlStringLiteralExpression( + start_expr=SqlStringLiteralExpression.create( literal_value=time_range_constraint.start_time.strftime(ISO8601_PYTHON_FORMAT), ), - end_expr=SqlStringLiteralExpression( + end_expr=SqlStringLiteralExpression.create( literal_value=time_range_constraint.end_time.strftime(ISO8601_PYTHON_FORMAT), ), ) @@ -254,7 +254,7 @@ def _make_time_spine_data_set( else: select_columns += ( SqlSelectColumn( - expr=SqlDateTruncExpression( + expr=SqlDateTruncExpression.create( time_granularity=agg_time_dimension_instance.spec.time_granularity, arg=column_expr ), column_alias=column_alias, @@ -264,10 +264,10 @@ def _make_time_spine_data_set( return SqlDataSet( instance_set=time_spine_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=TIME_SPINE_DATA_SET_DESCRIPTION, select_columns=select_columns, - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), + from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table), from_source_alias=time_spine_table_alias, group_bys=select_columns if apply_group_by else (), where=( @@ -353,14 +353,14 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=create_select_columns_for_instance_sets( self._column_association_resolver, table_alias_to_instance_set ), from_source=time_spine_data_set.checked_sql_select_node, from_source_alias=time_spine_data_set_alias, - joins_descs=(join_desc,), + join_descs=(join_desc,), ), ) @@ -443,14 +443,14 @@ def 
visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet: # clauses. return SqlDataSet( instance_set=InstanceSet.merge(list(table_alias_to_instance_set.values())), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=create_select_columns_for_instance_sets( self._column_association_resolver, table_alias_to_instance_set ), from_source=from_data_set.checked_sql_select_node, from_source_alias=from_data_set_alias, - joins_descs=tuple(sql_join_descs), + join_descs=tuple(sql_join_descs), ), ) @@ -518,7 +518,7 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> SqlDataS return SqlDataSet( instance_set=aggregated_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This will generate expressions with the appropriate aggregation functions e.g. SUM() select_columns=select_column_set.as_tuple(), @@ -576,14 +576,14 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: MetricSpec.from_reference(denominator.post_aggregation_reference) ).column_name - metric_expr = SqlRatioComputationExpression( - numerator=SqlColumnReferenceExpression( + metric_expr = SqlRatioComputationExpression.create( + numerator=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=numerator_column_name, ) ), - denominator=SqlColumnReferenceExpression( + denominator=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=denominator_column_name, @@ -619,7 +619,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: assert ( metric.type_params.expr ), "Derived metrics are required to have an `expr` in their YAML definition." 
- metric_expr = SqlStringExpression(sql_expr=metric.type_params.expr) + metric_expr = SqlStringExpression.create(sql_expr=metric.type_params.expr) elif metric.type == MetricType.CONVERSION: conversion_type_params = metric.type_params.conversion_type_params assert ( @@ -635,20 +635,20 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: ).column_name calculation_type = conversion_type_params.calculation - conversion_column_reference = SqlColumnReferenceExpression( + conversion_column_reference = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=conversion_measure_column, ) ) - base_column_reference = SqlColumnReferenceExpression( + base_column_reference = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=base_measure_column, ) ) if calculation_type == ConversionCalculationType.CONVERSION_RATE: - metric_expr = SqlRatioComputationExpression( + metric_expr = SqlRatioComputationExpression.create( numerator=conversion_column_reference, denominator=base_column_reference, ) @@ -699,7 +699,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=combined_select_column_set.as_tuple(), from_source=from_data_set.checked_sql_select_node, @@ -711,14 +711,14 @@ def __make_col_reference_or_coalesce_expr( self, column_name: str, input_measure: Optional[MetricInputMeasure], from_data_set_alias: str ) -> SqlExpressionNode: # Use a column reference to improve query optimization. 
- metric_expr: SqlExpressionNode = SqlColumnReferenceExpression( + metric_expr: SqlExpressionNode = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=from_data_set_alias, column_name=column_name) ) # Coalesce nulls to requested integer value, if requested. if input_measure and input_measure.fill_nulls_with is not None: - metric_expr = SqlAggregateFunctionExpression( + metric_expr = SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COALESCE, - sql_function_args=[metric_expr, SqlStringExpression(str(input_measure.fill_nulls_with))], + sql_function_args=[metric_expr, SqlStringExpression.create(str(input_measure.fill_nulls_with))], ) return metric_expr @@ -734,7 +734,7 @@ def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet: # no for order_by_spec in node.order_by_specs: order_by_descriptions.append( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=from_data_set_alias, column_name=self._column_association_resolver.resolve_spec( @@ -748,7 +748,7 @@ def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet: # no return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. 
select_columns=output_instance_set.transform( @@ -770,7 +770,7 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlD input_instance_set: InstanceSet = input_data_set.instance_set return SqlDataSet( instance_set=input_instance_set, - sql_node=SqlCreateTableAsNode( + sql_node=SqlCreateTableAsNode.create( sql_table=node.output_sql_table, parent_node=input_data_set.checked_sql_select_node, ), @@ -794,7 +794,7 @@ def visit_filter_elements_node(self, node: FilterElementsNode) -> SqlDataSet: group_bys = select_columns if node.distinct else () return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=select_columns, from_source=from_data_set.checked_sql_select_node, @@ -819,7 +819,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. 
select_columns=output_instance_set.transform( @@ -827,7 +827,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: ).as_tuple(), from_source=parent_data_set.checked_sql_select_node, from_source_alias=from_data_set_alias, - where=SqlStringExpression( + where=SqlStringExpression.create( sql_expr=node.where.where_sql, used_columns=tuple( column_association.column_name for column_association in column_associations_in_where_sql @@ -940,12 +940,12 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=combined_select_column_set.as_tuple(), from_source=from_data_set.data_set.checked_sql_select_node, from_source_alias=from_data_set.alias, - joins_descs=tuple(joins_descriptions), + join_descs=tuple(joins_descriptions), group_bys=linkable_select_column_set.as_tuple(), ), ) @@ -987,7 +987,7 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> SqlDa return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. select_columns=output_instance_set.transform( @@ -1073,7 +1073,7 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. 
select_columns=CreateSelectColumnsForInstances( @@ -1116,7 +1116,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe time_dimension_select_column = SqlSelectColumn( expr=SqlFunctionExpression.build_expression_from_aggregation_type( aggregation_type=node.agg_by_function, - sql_column_expression=SqlColumnReferenceExpression( + sql_column_expression=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=inner_join_data_set_alias, column_name=time_dimension_column_name, @@ -1138,7 +1138,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe entity_column_name = self.column_association_resolver.resolve_spec(entity_spec).column_name entity_select_columns.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=inner_join_data_set_alias, column_name=entity_column_name, @@ -1161,7 +1161,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe node.queried_time_dimension_spec ).column_name queried_time_dimension_select_column = SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=inner_join_data_set_alias, column_name=query_time_dimension_column_name, @@ -1174,7 +1174,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe if queried_time_dimension_select_column: row_filter_group_bys += (queried_time_dimension_select_column,) # Construct SelectNode for Row filtering - row_filter_sql_select_node = SqlSelectStatementNode( + row_filter_sql_select_node = SqlSelectStatementNode.create( description=f"Filter row on {node.agg_by_function.name}({time_dimension_column_name})", select_columns=row_filter_group_bys + (time_dimension_select_column,), from_source=from_data_set.checked_sql_select_node, @@ -1192,14 +1192,14 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> 
SqlDataSe ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=output_instance_set.transform( CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver) ).as_tuple(), from_source=from_data_set.checked_sql_select_node, from_source_alias=from_data_set_alias, - joins_descs=(sql_join_desc,), + join_descs=(sql_join_desc,), ), ) @@ -1292,7 +1292,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet original_time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0] time_spine_column_select_expr: Union[ SqlColumnReferenceExpression, SqlDateTruncExpression - ] = SqlColumnReferenceExpression( + ] = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name ) @@ -1329,25 +1329,25 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet select_expr: SqlExpressionNode = ( time_spine_column_select_expr if time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity - else SqlDateTruncExpression( + else SqlDateTruncExpression.create( time_granularity=time_dimension_spec.time_granularity, arg=time_spine_column_select_expr ) ) # Filter down to one row per granularity period requested in the group by. Any other granularities # included here will be filtered out in later nodes so should not be included in where filter. 
if need_where_filter and time_dimension_spec in node.requested_agg_time_dimension_specs: - new_where_filter = SqlComparisonExpression( + new_where_filter = SqlComparisonExpression.create( left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr ) where_filter = ( - SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)) + SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)) if where_filter else new_where_filter ) # Apply date_part to time spine column select expression. if time_dimension_spec.date_part: - select_expr = SqlExtractExpression(date_part=time_dimension_spec.date_part, arg=select_expr) + select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr) time_dim_spec = TimeDimensionSpec( element_name=original_time_spine_dim_instance.spec.element_name, entity_links=original_time_spine_dim_instance.spec.entity_links, @@ -1368,12 +1368,12 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet return SqlDataSet( instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=tuple(time_spine_select_columns) + parent_select_columns, from_source=time_spine_dataset.checked_sql_select_node, from_source_alias=time_spine_alias, - joins_descs=(join_description,), + join_descs=(join_description,), where=where_filter, ), ) @@ -1395,7 +1395,7 @@ def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 SqlSelectColumn( expr=SqlFunctionExpression.build_expression_from_aggregation_type( aggregation_type=agg_type, - sql_column_expression=SqlColumnReferenceExpression( + sql_column_expression=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=parent_table_alias, column_name=parent_column_alias) ), ), @@ 
-1408,7 +1408,7 @@ def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 return SqlDataSet( instance_set=parent_data_set.instance_set.transform(ConvertToMetadata(metadata_instances)), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=tuple(select_columns), from_source=parent_data_set.checked_sql_select_node, @@ -1438,12 +1438,12 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) ) ) gen_uuid_sql_select_column = SqlSelectColumn( - expr=SqlGenerateUuidExpression(), column_alias=output_column_association.column_name + expr=SqlGenerateUuidExpression.create(), column_alias=output_column_association.column_name ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description="Add column with generated UUID", select_columns=input_data_set.instance_set.transform( CreateSelectColumnsForInstances(input_data_set_alias, self._column_association_resolver) @@ -1531,10 +1531,10 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ) base_sql_select_columns = tuple( SqlSelectColumn( - expr=SqlWindowFunctionExpression( + expr=SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=base_data_set_alias, column_name=base_sql_column_reference.col_ref.column_name, @@ -1542,7 +1542,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ) ], partition_by_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=conversion_data_set_alias, column_name=column, @@ -1552,7 +1552,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ], order_by_args=[ SqlWindowOrderByArgument( - 
expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=base_data_set_alias, column_name=base_time_dimension_column_name, @@ -1574,7 +1574,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S # Deduplicate the fanout results conversion_unique_key_select_columns = tuple( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=conversion_data_set_alias, column_name=column_name, @@ -1587,14 +1587,14 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S additional_conversion_select_columns = conversion_data_set_output_instance_set.transform( CreateSelectColumnsForInstances(conversion_data_set_alias, self._column_association_resolver) ).as_tuple() - deduped_sql_select_node = SqlSelectStatementNode( + deduped_sql_select_node = SqlSelectStatementNode.create( description=f"Dedupe the fanout with {','.join(spec.qualified_name for spec in node.unique_identifier_keys)} in the conversion data set", select_columns=base_sql_select_columns + conversion_unique_key_select_columns + additional_conversion_select_columns, from_source=base_data_set.checked_sql_select_node, from_source_alias=base_data_set_alias, - joins_descs=(sql_join_description,), + join_descs=(sql_join_description,), distinct=True, ) @@ -1605,7 +1605,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=output_instance_set.transform( CreateSelectColumnsForInstances(output_data_set_alias, self._column_association_resolver) @@ -1655,7 +1655,7 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlD ) ) metric_select_column = SqlSelectColumn( - expr=SqlWindowFunctionExpression( + 
expr=SqlWindowFunctionExpression.create( sql_function=sql_window_function, sql_function_args=[ SqlColumnReferenceExpression.from_table_and_column_names( @@ -1687,7 +1687,7 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlD ).as_tuple() + ( metric_select_column, ) - subquery = SqlSelectStatementNode( + subquery = SqlSelectStatementNode.create( description="Window Function for Metric Re-aggregation", select_columns=subquery_select_columns, from_source=from_data_set.checked_sql_select_node, @@ -1700,7 +1700,7 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlD ).as_tuple() return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description="Re-aggregate Metric via Group By", select_columns=outer_query_select_columns, from_source=subquery, diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 0d816209ed..c33fc5bbb3 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -162,7 +162,7 @@ def _make_sql_column_expression( input_column_name = self._output_to_input_column_mapping[output_column_name] select_columns.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference(self._table_alias, input_column_name)), + expr=SqlColumnReferenceExpression.create(SqlColumnReference(self._table_alias, input_column_name)), column_alias=output_column_name, ) ) @@ -223,7 +223,7 @@ def _make_sql_column_expression_to_aggregate_measure( measure = self._semantic_model_lookup.get_measure(measure_instance.spec.reference) aggregation_type = measure.agg - expression_to_get_measure = SqlColumnReferenceExpression( + expression_to_get_measure = SqlColumnReferenceExpression.create( SqlColumnReference(self._table_alias, column_name_in_table) ) @@ -824,7 +824,7 @@ def transform(self, instance_set: 
InstanceSet) -> Tuple[SqlColumnReferenceExpres self._column_resolver.resolve_spec(spec).column_name for spec in instance_set.spec_set.all_specs ] return tuple( - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=self._table_alias, column_name=column_name, @@ -854,7 +854,7 @@ def __init__( # noqa: D107 def _create_select_column(self, spec: InstanceSpec, fill_nulls_with: Optional[int] = None) -> SqlSelectColumn: """Creates the select column for the given spec and the fill value.""" column_name = self._column_resolver.resolve_spec(spec).column_name - column_reference_expression = SqlColumnReferenceExpression( + column_reference_expression = SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=self._table_alias, column_name=column_name, @@ -864,11 +864,11 @@ def _create_select_column(self, spec: InstanceSpec, fill_nulls_with: Optional[in aggregation_type=AggregationType.MAX, sql_column_expression=column_reference_expression ) if fill_nulls_with is not None: - select_expression = SqlAggregateFunctionExpression( + select_expression = SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COALESCE, sql_function_args=[ select_expression, - SqlStringExpression(str(fill_nulls_with)), + SqlStringExpression.create(str(fill_nulls_with)), ], ) return SqlSelectColumn( diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 8b68e22705..57f4f1d8ea 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -379,7 +379,9 @@ def _add_time_range_constraint( break if constrain_time: processed_nodes.append( - ConstrainTimeRangeNode(parent_node=source_node, time_range_constraint=time_range_constraint) + ConstrainTimeRangeNode.create( + parent_node=source_node, time_range_constraint=time_range_constraint + ) ) else: processed_nodes.append(source_node) @@ -421,7 +423,7 @@ def _add_where_constraint( 
filtered_nodes.append(source_node) else: filtered_nodes.append( - WhereConstraintNode(parent_node=source_node, where_specs=matching_filter_specs) + WhereConstraintNode.create(parent_node=source_node, where_specs=matching_filter_specs) ) else: filtered_nodes.append(source_node) @@ -531,7 +533,7 @@ def _get_candidates_nodes_for_multi_hop( # filter measures out of joinable_node specs = data_set_of_second_node_that_can_be_joined.instance_set.spec_set - filtered_joinable_node = FilterElementsNode( + filtered_joinable_node = FilterElementsNode.create( parent_node=second_node_that_could_be_joined, include_specs=group_specs_by_type( specs.dimension_specs @@ -552,7 +554,7 @@ def _get_candidates_nodes_for_multi_hop( multi_hop_join_candidates.append( MultiHopJoinCandidate( - node_with_multi_hop_elements=JoinOnEntitiesNode( + node_with_multi_hop_elements=JoinOnEntitiesNode.create( left_node=first_node_that_could_be_joined, join_targets=[ JoinDescription( diff --git a/metricflow/plan_conversion/sql_expression_builders.py b/metricflow/plan_conversion/sql_expression_builders.py index c3567ab163..e5ed18d463 100644 --- a/metricflow/plan_conversion/sql_expression_builders.py +++ b/metricflow/plan_conversion/sql_expression_builders.py @@ -25,7 +25,7 @@ def make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> SqlE COALESCE(a.is_instant, b.is_instant) """ if len(table_aliases) == 1: - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=table_aliases[0], column_name=column_alias, @@ -35,14 +35,14 @@ def make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> SqlE columns_to_coalesce: List[SqlExpressionNode] = [] for table_alias in table_aliases: columns_to_coalesce.append( - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=table_alias, column_name=column_alias, ) ) ) - return SqlAggregateFunctionExpression( + return 
SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COALESCE, sql_function_args=columns_to_coalesce, ) diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 7a58689b67..7bc24b7f11 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -97,30 +97,30 @@ def make_column_equality_sql_join_description( and_conditions: List[SqlExpressionNode] = [] for column_equality_description in column_equality_descriptions: - left_column = SqlColumnReferenceExpression( + left_column = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=left_source_alias, column_name=column_equality_description.left_column_alias, ) ) - right_column = SqlColumnReferenceExpression( + right_column = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=right_source_alias, column_name=column_equality_description.right_column_alias, ) ) - column_equality_expression = SqlComparisonExpression( + column_equality_expression = SqlComparisonExpression.create( left_expr=left_column, comparison=SqlComparison.EQUALS, right_expr=right_column, ) if column_equality_description.treat_nulls_as_equal: - null_comparison_expression = SqlLogicalExpression( + null_comparison_expression = SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, - args=(SqlIsNullExpression(arg=left_column), SqlIsNullExpression(arg=right_column)), + args=(SqlIsNullExpression.create(arg=left_column), SqlIsNullExpression.create(arg=right_column)), ) and_conditions.append( - SqlLogicalExpression( + SqlLogicalExpression.create( operator=SqlLogicalOperator.OR, args=(column_equality_expression, null_comparison_expression) ) ) @@ -135,7 +135,7 @@ def make_column_equality_sql_join_description( elif len(and_conditions) == 1: on_condition = and_conditions[0] else: - on_condition = SqlLogicalExpression(operator=SqlLogicalOperator.AND, args=tuple(and_conditions)) + on_condition = 
SqlLogicalExpression.create(operator=SqlLogicalOperator.AND, args=tuple(and_conditions)) return SqlJoinDescription( right_source=right_source_node, @@ -287,32 +287,32 @@ def _make_time_window_join_condition( {start_dimension_name} >= metric_time AND ({end_dimension_name} < metric_time OR {end_dimension_name} IS NULL) """ - left_time_column_expr = SqlColumnReferenceExpression( + left_time_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=left_source_alias, column_name=left_source_time_dimension_name) ) - window_start_column_expr = SqlColumnReferenceExpression( + window_start_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=right_source_alias, column_name=window_start_dimension_name) ) - window_end_column_expr = SqlColumnReferenceExpression( + window_end_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=right_source_alias, column_name=window_end_dimension_name) ) - window_start_condition = SqlComparisonExpression( + window_start_condition = SqlComparisonExpression.create( left_expr=left_time_column_expr, comparison=SqlComparison.GREATER_THAN_OR_EQUALS, right_expr=window_start_column_expr, ) - window_end_by_time = SqlComparisonExpression( + window_end_by_time = SqlComparisonExpression.create( left_expr=left_time_column_expr, comparison=SqlComparison.LESS_THAN, right_expr=window_end_column_expr, ) - window_end_is_null = SqlIsNullExpression(window_end_column_expr) - window_end_condition = SqlLogicalExpression( + window_end_is_null = SqlIsNullExpression.create(window_end_column_expr) + window_end_condition = SqlLogicalExpression.create( operator=SqlLogicalOperator.OR, args=(window_end_by_time, window_end_is_null) ) - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=(window_start_condition, window_end_condition) ) @@ -346,7 +346,7 @@ def make_join_description_for_combining_datasets( for colname in column_names ] 
on_condition = ( - SqlLogicalExpression(operator=SqlLogicalOperator.AND, args=tuple(equality_exprs)) + SqlLogicalExpression.create(operator=SqlLogicalOperator.AND, args=tuple(equality_exprs)) if len(equality_exprs) > 1 else equality_exprs[0] ) @@ -403,10 +403,10 @@ def _make_equality_expression_for_full_outer_join( The latter scenario consolidates the rows keyed by 'c' into a single entry. """ - return SqlComparisonExpression( + return SqlComparisonExpression.create( left_expr=make_coalesced_expr(table_aliases_in_coalesce, column_alias), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=right_table_alias, column_name=column_alias, @@ -430,13 +430,13 @@ def _make_time_range_window_join_condition( """ if window or grain_to_date: assert_exactly_one_arg_set(window=window, grain_to_date=grain_to_date) - base_column_expr = SqlColumnReferenceExpression( + base_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=base_data_set.alias, column_name=base_data_set.metric_time_column_name, ) ) - time_comparison_column_expr = SqlColumnReferenceExpression( + time_comparison_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=time_comparison_dataset.alias, column_name=time_comparison_dataset.metric_time_column_name, @@ -445,7 +445,7 @@ def _make_time_range_window_join_condition( # Comparison expression against the endpoint of the cumulative time range, # meaning the base metrc time must always be BEFORE the comparison metric time - end_of_range_comparison_expression = SqlComparisonExpression( + end_of_range_comparison_expression = SqlComparisonExpression.create( left_expr=base_column_expr, comparison=SqlComparison.LESS_THAN_OR_EQUALS, right_expr=time_comparison_column_expr, @@ -453,10 +453,10 @@ def _make_time_range_window_join_condition( comparison_expressions: List[SqlComparisonExpression] = 
[end_of_range_comparison_expression] if window: - start_of_range_comparison_expr = SqlComparisonExpression( + start_of_range_comparison_expr = SqlComparisonExpression.create( left_expr=base_column_expr, comparison=SqlComparison.GREATER_THAN, - right_expr=SqlSubtractTimeIntervalExpression( + right_expr=SqlSubtractTimeIntervalExpression.create( arg=time_comparison_column_expr, count=window.count, granularity=window.granularity, @@ -464,14 +464,16 @@ def _make_time_range_window_join_condition( ) comparison_expressions.append(start_of_range_comparison_expr) elif grain_to_date: - start_of_range_comparison_expr = SqlComparisonExpression( + start_of_range_comparison_expr = SqlComparisonExpression.create( left_expr=base_column_expr, comparison=SqlComparison.GREATER_THAN_OR_EQUALS, - right_expr=SqlDateTruncExpression(arg=time_comparison_column_expr, time_granularity=grain_to_date), + right_expr=SqlDateTruncExpression.create( + arg=time_comparison_column_expr, time_granularity=grain_to_date + ), ) comparison_expressions.append(start_of_range_comparison_expr) - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=tuple(comparison_expressions), ) @@ -537,23 +539,23 @@ def make_join_to_time_spine_join_description( parent_alias: str, ) -> SqlJoinDescription: """Build join expression used to join a metric to a time spine dataset.""" - left_expr: SqlExpressionNode = SqlColumnReferenceExpression( + left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name) ) if node.offset_window: - left_expr = SqlSubtractTimeIntervalExpression( + left_expr = SqlSubtractTimeIntervalExpression.create( arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity ) elif node.offset_to_grain: - left_expr = SqlDateTruncExpression(time_granularity=node.offset_to_grain, arg=left_expr) + left_expr = 
SqlDateTruncExpression.create(time_granularity=node.offset_to_grain, arg=left_expr) return SqlJoinDescription( right_source=parent_sql_select_node, right_source_alias=parent_alias, - on_condition=SqlComparisonExpression( + on_condition=SqlComparisonExpression.create( left_expr=left_expr, comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias=parent_alias, column_name=agg_time_dimension_column_name) ), ), diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 75c62ca7b8..61bd283bf1 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -98,12 +98,12 @@ def _prune_columns_from_grandparents( else: pruned_join_descriptions.append(join_description) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=pruned_select_columns, from_source=pruned_from_source, from_source_alias=node.from_source_alias, - joins_descs=tuple(pruned_join_descriptions), + join_descs=tuple(pruned_join_descriptions), group_bys=node.group_bys, order_bys=node.order_bys, where=node.where, @@ -178,12 +178,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ) ) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=tuple(pruned_select_columns), from_source=pruned_from_source, from_source_alias=node.from_source_alias, - joins_descs=tuple(pruned_join_descriptions), + join_descs=tuple(pruned_join_descriptions), group_bys=node.group_bys, order_bys=node.order_bys, where=node.where, @@ -200,7 +200,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return 
SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 18ffb95a99..e587b6510f 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -65,7 +65,7 @@ def combine_wheres(self, additional_where_clauses: List[SqlExpressionNode]) -> O if len(all_where_clauses) == 1: return all_where_clauses[0] elif len(all_where_clauses) > 1: - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=tuple(all_where_clauses), ) @@ -93,12 +93,12 @@ def _reduce_parents( node: SqlSelectStatementNode, ) -> SqlSelectStatementNode: """Apply the reducing operation to the parent select statements.""" - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -401,7 +401,7 @@ def _rewrite_where( # For type checking. The above conditionals should ensure the below. 
assert node_where assert parent_node_where - return SqlLogicalExpression(operator=SqlLogicalOperator.AND, args=(node_where, parent_node_where)) + return SqlLogicalExpression.create(operator=SqlLogicalOperator.AND, args=(node_where, parent_node_where)) @staticmethod def _find_matching_select_column( @@ -568,12 +568,12 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN for x in new_join_descs ] - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=tuple(clauses_to_rewrite.select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(new_join_descs), + join_descs=tuple(new_join_descs), group_bys=tuple(clauses_to_rewrite.group_bys), order_bys=tuple(clauses_to_rewrite.order_bys), where=clauses_to_rewrite.combine_wheres(additional_where_clauses), @@ -656,7 +656,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP new_order_bys.append( SqlOrderByDescription( - expr=SqlColumnAliasReferenceExpression(column_alias=matching_select_column.column_alias), + expr=SqlColumnAliasReferenceExpression.create(column_alias=matching_select_column.column_alias), desc=order_by_item.desc, ) ) @@ -681,14 +681,14 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP elif parent_select_node.group_bys: new_group_bys = parent_select_node.group_bys - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="\n".join([parent_select_node.description, node_with_reduced_parents.description]), select_columns=SqlRewritingSubQueryReducerVisitor._rewrite_select_columns( old_select_columns=node.select_columns, column_replacements=column_replacements ), from_source=parent_select_node.from_source, from_source_alias=parent_select_node.from_source_alias, - joins_descs=parent_select_node.join_descs, + join_descs=parent_select_node.join_descs, group_bys=new_group_bys, 
order_bys=tuple(new_order_bys), where=SqlRewritingSubQueryReducerVisitor._rewrite_where( @@ -707,7 +707,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) @@ -735,7 +735,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP if matching_select_column: new_group_bys.append( SqlSelectColumn( - expr=SqlColumnAliasReferenceExpression(column_alias=matching_select_column.column_alias), + expr=SqlColumnAliasReferenceExpression.create(column_alias=matching_select_column.column_alias), column_alias=matching_select_column.column_alias, ) ) @@ -743,12 +743,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP logger.info(f"Did not find matching select for {group_by} in:\n{indent(node.structure_text())}") new_group_bys.append(group_by) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -771,7 +771,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index be649142ff..a3b440cc67 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py 
+++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -27,12 +27,12 @@ def _reduce_parents( node: SqlSelectStatementNode, ) -> SqlSelectStatementNode: """Apply the reducing operation to the parent select statements.""" - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -158,7 +158,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP return node_with_reduced_parents new_order_by.append( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias_in_parent, column_name=order_by_item_expr.col_ref.column_name, @@ -175,12 +175,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP elif parent_select_node.limit is not None: new_limit = min(new_limit, parent_select_node.limit) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="\n".join([parent_select_node.description, node_with_reduced_parents.description]), select_columns=parent_select_node.select_columns, from_source=parent_select_node.from_source, from_source_alias=parent_select_node.from_source_alias, - joins_descs=parent_select_node.join_descs, + join_descs=parent_select_node.join_descs, group_bys=parent_select_node.group_bys, order_bys=tuple(new_order_by), where=parent_select_node.where, @@ -195,7 +195,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, 
parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index baebd8eefd..9f32cbefa1 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -26,7 +26,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP should_simplify_table_aliases = len(node.parent_nodes) <= 1 if should_simplify_table_aliases: - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=tuple( SqlSelectColumn(expr=x.expr.rewrite(should_render_table_alias=False), column_alias=x.column_alias) @@ -47,12 +47,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP distinct=node.distinct, ) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -75,7 +75,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index f507bd2969..7afdf8c026 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -16,18 +16,16 @@ from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix 
-from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty, NodeId +from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from metricflow_semantics.visitor import Visitable, VisitorOutputT +from typing_extensions import override -class SqlExpressionNode(DagNode, Visitable, ABC): +@dataclass(frozen=True, order=True) +class SqlExpressionNode(DagNode["SqlExpressionNode"], Visitable, ABC): """An SQL expression like my_table.my_column, CONCAT(a, b) or 1 + 1 that evaluates to a value.""" - def __init__(self, node_id: NodeId, parent_nodes: List[SqlExpressionNode]) -> None: # noqa: D107 - self._parent_nodes = parent_nodes - super().__init__(node_id=node_id) - @property @abstractmethod def requires_parenthesis(self) -> bool: @@ -35,7 +33,7 @@ def requires_parenthesis(self) -> bool: Useful for string expressions where we can't infer the structure. For example, in rendering - SqlMathExpression(operator="*", left_expr=SqlStringExpression("a"), right_expr=SqlStringExpression("b + c") + SqlMathExpression(operator="*", left_expr=SqlStringExpression.create("a"), right_expr=SqlStringExpression.create("b + c") this can be used to differentiate between @@ -57,10 +55,6 @@ def bind_parameters(self) -> SqlBindParameters: """ return SqlBindParameters() - @property - def parent_nodes(self) -> Sequence[SqlExpressionNode]: # noqa: D102 - return self._parent_nodes - @property def as_column_reference_expression(self) -> Optional[SqlColumnReferenceExpression]: """If this is a column reference expression, return self.""" @@ -146,7 +140,7 @@ def contains_aggregate_exprs(self) -> bool: # noqa: D102 class SqlColumnReplacements: - """When re-writing column references in expressions, this storing the mapping.""" + """When re-writing column references in expressions, this stores the mapping.""" def __init__(self, column_replacements: Dict[SqlColumnReference, SqlExpressionNode]) -> None: # noqa: 
D107 self._column_replacements = column_replacements @@ -236,35 +230,41 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOu pass +@dataclass(frozen=True) class SqlStringExpression(SqlExpressionNode): """An SQL expression in a string format, so it lacks information about the structure. These are convenient to use, but because structure is lacking, it can't be easily handled for DB rendering and can impede optimizations. + + Attributes: + sql_expr: The SQL in string form. + bind_parameters: See SqlExpressionNode.bind_parameters + requires_parenthesis: Whether this should be rendered with () if nested in another expression. + used_columns: If set, indicates that the expression represented by the string only uses those columns. e.g. + sql_expr="a + b", used_columns=["a", "b"]. This may be used by optimizers, and if specified, it must be + complete. e.g. sql_expr="a + b + c", used_columns=["a", "b"] will cause problems. """ - def __init__( - self, + sql_expr: str + bind_parameters: SqlBindParameters = SqlBindParameters() + requires_parenthesis: bool = True + used_columns: Optional[Tuple[str, ...]] = None + + @staticmethod + def create( # noqa: D102 sql_expr: str, - bind_parameters: Optional[SqlBindParameters] = None, + bind_parameters: SqlBindParameters = SqlBindParameters(), requires_parenthesis: bool = True, used_columns: Optional[Tuple[str, ...]] = None, - ) -> None: - """Constructor. - - Args: - sql_expr: The SQL in string form. - bind_parameters: See SqlExpressionNode.bind_parameters - requires_parenthesis: Whether this should be rendered with () if nested in another expression. - used_columns: If set, indicates that the expression represented by the string only uses those columns. e.g. - sql_expr="a + b", used_columns=["a", "b"]. This may be used by optimizers, and if specified, it must be - complete. e.g. sql_expr="a + b + c", used_columns=["a", "b"] will cause problems. 
- """ - self._sql_expr = sql_expr - self._bind_parameters = bind_parameters or SqlBindParameters() - self._requires_parenthesis = requires_parenthesis - self._used_columns = used_columns - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + ) -> SqlStringExpression: + return SqlStringExpression( + parent_nodes=(), + sql_expr=sql_expr, + bind_parameters=bind_parameters, + requires_parenthesis=requires_parenthesis, + used_columns=used_columns, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -275,29 +275,15 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu @property def description(self) -> str: # noqa: D102 - return f"String SQL Expression: {self._sql_expr}" + return f"String SQL Expression: {self.sql_expr}" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - return tuple(super().displayed_properties) + (DisplayedProperty("sql_expr", self._sql_expr),) - - @property - def sql_expr(self) -> str: # noqa: D102 - return self._sql_expr - - @property - def requires_parenthesis(self) -> bool: # noqa: D102 - return self._requires_parenthesis - - @property - def bind_parameters(self) -> SqlBindParameters: # noqa: D102 - return self._bind_parameters + return tuple(super().displayed_properties) + (DisplayedProperty("sql_expr", self.sql_expr),) + @override @property - def used_columns(self) -> Optional[Tuple[str, ...]]: # noqa: D102 - return self._used_columns - - def __repr__(self) -> str: # noqa: D105 + def pretty_format(self) -> str: return f"{self.__class__.__name__}(node_id={self.node_id} sql_expr={self.sql_expr})" def rewrite( # noqa: D102 @@ -328,12 +314,15 @@ def as_string_expression(self) -> Optional[SqlStringExpression]: return self +@dataclass(frozen=True) class SqlStringLiteralExpression(SqlExpressionNode): """A string literal like 'foo'. 
It shouldn't include delimiters as it should be added during rendering.""" - def __init__(self, literal_value: str) -> None: # noqa: D107 - self._literal_value = literal_value - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + literal_value: str + + @staticmethod + def create(literal_value: str) -> SqlStringLiteralExpression: # noqa: D102 + return SqlStringLiteralExpression(parent_nodes=(), literal_value=literal_value) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -344,15 +333,11 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu @property def description(self) -> str: # noqa: D102 - return f"String Literal: {self._literal_value}" + return f"String Literal: {self.literal_value}" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - return tuple(super().displayed_properties) + (DisplayedProperty("value", self._literal_value),) - - @property - def literal_value(self) -> str: # noqa: D102 - return self._literal_value + return tuple(super().displayed_properties) + (DisplayedProperty("value", self.literal_value),) @property def requires_parenthesis(self) -> bool: # noqa: D102 @@ -390,23 +375,34 @@ class SqlColumnReference: column_name: str +@dataclass(frozen=True) class SqlColumnReferenceExpression(SqlExpressionNode): """An expression that evaluates to the value of a column in one of the sources in the select query. e.g. my_table.my_column + + Attributes: + col_ref: the associated column reference. + should_render_table_alias: When converting this to SQL text, whether the table alias needed to be included. + e.g. "foo.bar" vs "bar". """ - def __init__(self, col_ref: SqlColumnReference, should_render_table_alias: bool = True) -> None: - """Constructor. + col_ref: SqlColumnReference + should_render_table_alias: bool - Args: - col_ref: the associated column reference. 
- should_render_table_alias: When converting this to SQL text, whether the table alias needed to be included. - e.g. "foo.bar" vs "bar". - """ - self._col_ref = col_ref - self._should_render_table_alias = should_render_table_alias - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 0 + + @staticmethod + def create( # noqa: D102 + col_ref: SqlColumnReference, should_render_table_alias: bool = True + ) -> SqlColumnReferenceExpression: + return SqlColumnReferenceExpression( + parent_nodes=(), + col_ref=col_ref, + should_render_table_alias=should_render_table_alias, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -415,10 +411,6 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_column_reference_expr(self) - @property - def col_ref(self) -> SqlColumnReference: # noqa: D102 - return self._col_ref - @property def description(self) -> str: # noqa: D102 return f"Column: {self.col_ref}" @@ -455,17 +447,17 @@ def rewrite( # noqa: D102 return replacement else: if should_render_table_alias is not None: - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( col_ref=self.col_ref, should_render_table_alias=should_render_table_alias ) return self if should_render_table_alias is not None: - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( col_ref=self.col_ref, should_render_table_alias=should_render_table_alias ) - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( col_ref=self.col_ref, should_render_table_alias=self.should_render_table_alias ) @@ -473,10 +465,6 @@ def rewrite( # noqa: D102 def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 return SqlExpressionTreeLineage(column_reference_exprs=(self,)) - @property - def 
should_render_table_alias(self) -> bool: # noqa: D102 - return self._should_render_table_alias - def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 if not isinstance(other, SqlColumnReferenceExpression): return False @@ -484,9 +472,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 @staticmethod def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102 - return SqlColumnReferenceExpression(SqlColumnReference(table_alias=table_alias, column_name=column_name)) + return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name)) +@dataclass(frozen=True) class SqlColumnAliasReferenceExpression(SqlExpressionNode): """An expression that evaluates to the alias of a column, but is not qualified with a table alias. @@ -496,9 +485,14 @@ class SqlColumnAliasReferenceExpression(SqlExpressionNode): ambiguities. """ - def __init__(self, column_alias: str) -> None: # noqa: D107 - self._column_alias = column_alias - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + column_alias: str + + @staticmethod + def create(column_alias: str) -> SqlColumnAliasReferenceExpression: # noqa: D102 + return SqlColumnAliasReferenceExpression( + parent_nodes=(), + column_alias=column_alias, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -507,13 +501,9 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_column_alias_reference_expr(self) - @property - def column_alias(self) -> str: # noqa: D102 - return self._column_alias - @property def description(self) -> str: # noqa: D102 - return f"Unqualified Column: {self._column_alias}" + return f"Unqualified Column: {self.column_alias}" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -554,24 +544,29 @@ class SqlComparison(Enum): # 
noqa: D101 EQUALS = "=" +@dataclass(frozen=True) class SqlComparisonExpression(SqlExpressionNode): """A comparison using >, <, <=, >=, =. e.g. my_table.my_column = a + b + + Attributes: + left_expr: The expression on the left side of the = + comparison: The comparison to use on expressions + right_expr: The expression on the right side of the = """ - def __init__(self, left_expr: SqlExpressionNode, comparison: SqlComparison, right_expr: SqlExpressionNode) -> None: - """Constructor. + left_expr: SqlExpressionNode + comparison: SqlComparison + right_expr: SqlExpressionNode - Args: - left_expr: The expression on the left side of the = - comparison: The comparison to use on expressions - right_expr: The expression on the right side of the = - """ - self._left_expr = left_expr - self._comparison = comparison - self._right_expr = right_expr - super().__init__(node_id=self.create_unique_id(), parent_nodes=[self._left_expr, self._right_expr]) + @staticmethod + def create( # noqa: D102 + left_expr: SqlExpressionNode, comparison: SqlComparison, right_expr: SqlExpressionNode + ) -> SqlComparisonExpression: + return SqlComparisonExpression( + parent_nodes=(left_expr, right_expr), left_expr=left_expr, comparison=comparison, right_expr=right_expr + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -580,17 +575,9 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_comparison_expr(self) - @property - def left_expr(self) -> SqlExpressionNode: # noqa: D102 - return self._left_expr - - @property - def right_expr(self) -> SqlExpressionNode: # noqa: D102 - return self._right_expr - @property def description(self) -> str: # noqa: D102 - return f"{self._comparison.value} Expression" + return f"{self.comparison.value} Expression" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -604,16 +591,12 @@ def 
displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 def requires_parenthesis(self) -> bool: # noqa: D102 return True - @property - def comparison(self) -> SqlComparison: # noqa: D102 - return self._comparison - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlComparisonExpression( + return SqlComparisonExpression.create( left_expr=self.left_expr.rewrite(column_replacements, should_render_table_alias), comparison=self.comparison, right_expr=self.right_expr.rewrite(column_replacements, should_render_table_alias), @@ -651,7 +634,7 @@ class SqlFunction(Enum): @staticmethod def distinct_aggregation_functions() -> Sequence[SqlFunction]: - """Returns a tuple containg all currently-supported DISTINCT type aggregation functions. + """Returns a tuple containing all currently-supported DISTINCT type aggregation functions. This is not a property because properties don't play nicely with static/class methods. """ @@ -733,37 +716,45 @@ def build_expression_from_aggregation_type( """Returns sql function expression depending on aggregation type.""" if aggregation_type is AggregationType.PERCENTILE: assert agg_params is not None, "Agg_params is none, which should have been caught in validation" - return SqlPercentileExpression( + return SqlPercentileExpression.create( sql_column_expression, SqlPercentileExpressionArgument.from_aggregation_parameters(agg_params) ) else: return SqlAggregateFunctionExpression.from_aggregation_type(aggregation_type, sql_column_expression) +@dataclass(frozen=True) class SqlAggregateFunctionExpression(SqlFunctionExpression): - """An aggregate function expression like SUM(1).""" + """An aggregate function expression like SUM(1). + + Attributes: + sql_function: The function that this represents. + sql_function_args: The arguments that should go into the function. e.g. 
for "CONCAT(a, b)", the arg + expressions should be "a" and "b". + """ + + sql_function: SqlFunction + sql_function_args: Tuple[SqlExpressionNode, ...] @staticmethod def from_aggregation_type( aggregation_type: AggregationType, sql_column_expression: SqlColumnReferenceExpression ) -> SqlAggregateFunctionExpression: """Given the aggregation type, return an SQL function expression that does that aggregation on the given col.""" - return SqlAggregateFunctionExpression( + return SqlAggregateFunctionExpression.create( sql_function=SqlFunction.from_aggregation_type(aggregation_type=aggregation_type), sql_function_args=[sql_column_expression], ) - def __init__(self, sql_function: SqlFunction, sql_function_args: List[SqlExpressionNode]) -> None: - """Constructor. - - Args: - sql_function: The function that this represents. - sql_function_args: The arguments that should go into the function. e.g. for "CONCAT(a, b)", the arg - expressions should be "a" and "b". - """ - self._sql_function = sql_function - self._sql_function_args = sql_function_args - super().__init__(node_id=self.create_unique_id(), parent_nodes=sql_function_args) + @staticmethod + def create( # noqa: D102 + sql_function: SqlFunction, sql_function_args: Sequence[SqlExpressionNode] + ) -> SqlAggregateFunctionExpression: + return SqlAggregateFunctionExpression( + parent_nodes=tuple(sql_function_args), + sql_function=sql_function, + sql_function_args=tuple(sql_function_args), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -778,7 +769,7 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu @property def description(self) -> str: # noqa: D102 - return f"{self._sql_function.value} Expression" + return f"{self.sql_function.value} Expression" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -789,14 +780,8 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 ) @property - def sql_function(self) -> 
SqlFunction: # noqa: D102 - return self._sql_function - - @property - def sql_function_args(self) -> List[SqlExpressionNode]: # noqa: D102 - return self._sql_function_args - - def __repr__(self) -> str: # noqa: D105 + @override + def pretty_format(self) -> str: # noqa: D105 return f"{self.__class__.__name__}(node_id={self.node_id}, sql_function={self.sql_function.name})" def rewrite( # noqa: D102 @@ -804,7 +789,7 @@ def rewrite( # noqa: D102 column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlAggregateFunctionExpression( + return SqlAggregateFunctionExpression.create( sql_function=self.sql_function, sql_function_args=[ x.rewrite(column_replacements, should_render_table_alias) for x in self.sql_function_args @@ -872,21 +857,28 @@ def from_aggregation_parameters(agg_params: MeasureAggregationParameters) -> Sql ) +@dataclass(frozen=True) class SqlPercentileExpression(SqlFunctionExpression): - """A percentile aggregation expression.""" + """A percentile aggregation expression. - def __init__(self, order_by_arg: SqlExpressionNode, percentile_args: SqlPercentileExpressionArgument) -> None: - """Constructor. + Attributes: + order_by_arg: The expression that should go into the function. e.g. for "percentile_cont(col, 0.1)", the arg + expressions should be "col". + percentile_args: Auxillary information including percentile value and type. + """ - Args: - order_by_arg: The expression that should go into the function. e.g. for "percentile_cont(col, 0.1)", the arg - expressions should be "col". - percentile_args: Auxillary information including percentile value and type. 
- """ - self._order_by_arg = order_by_arg - self._percentile_args = percentile_args + order_by_arg: SqlExpressionNode + percentile_args: SqlPercentileExpressionArgument - super().__init__(node_id=self.create_unique_id(), parent_nodes=[order_by_arg]) + @staticmethod + def create( # noqa: D102 + order_by_arg: SqlExpressionNode, percentile_args: SqlPercentileExpressionArgument + ) -> SqlPercentileExpression: + return SqlPercentileExpression( + parent_nodes=(order_by_arg,), + order_by_arg=order_by_arg, + percentile_args=percentile_args, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -896,40 +888,32 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def requires_parenthesis(self) -> bool: # noqa: D102 return False - @property - def order_by_arg(self) -> SqlExpressionNode: # noqa: D102 - return self._order_by_arg - - @property - def percentile_args(self) -> SqlPercentileExpressionArgument: # noqa: D102 - return self._percentile_args - def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_percentile_expr(self) @property def description(self) -> str: # noqa: D102 - return f"{self._percentile_args.function_type.value} Percentile({self._percentile_args.percentile}) Expression" + return f"{self.percentile_args.function_type.value} Percentile({self.percentile_args.percentile}) Expression" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return ( tuple(super().displayed_properties) - + (DisplayedProperty("argument", self._order_by_arg),) - + (DisplayedProperty("percentile_args", self._percentile_args),) + + (DisplayedProperty("argument", self.order_by_arg),) + + (DisplayedProperty("percentile_args", self.percentile_args),) ) def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(node_id={self.node_id}, percentile={self._percentile_args.percentile}, function_type={self._percentile_args.function_type.value})" + return 
f"{self.__class__.__name__}(node_id={self.node_id}, percentile={self.percentile_args.percentile}, function_type={self.percentile_args.function_type.value})" def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlPercentileExpression( - order_by_arg=self._order_by_arg.rewrite(column_replacements, should_render_table_alias), - percentile_args=self._percentile_args, + return SqlPercentileExpression.create( + order_by_arg=self.order_by_arg.rewrite(column_replacements, should_render_table_alias), + percentile_args=self.percentile_args, ) @property @@ -945,7 +929,7 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 if not isinstance(other, SqlPercentileExpression): return False - return self._percentile_args == other._percentile_args and self._parents_match(other) + return self.percentile_args == other.percentile_args and self._parents_match(other) class SqlWindowFunction(Enum): @@ -1000,30 +984,31 @@ def suffix(self) -> str: return " ".join(result) +@dataclass(frozen=True) class SqlWindowFunctionExpression(SqlFunctionExpression): - """A window function expression like SUM(foo) OVER bar.""" + """A window function expression like SUM(foo) OVER bar. + + Attributes: + sql_function: The function that this represents. + sql_function_args: The arguments that should go into the function. e.g. for "CONCAT(a, b)", the arg + expressions should be "a" and "b". + partition_by_args: The arguments to partition the rows. e.g. PARTITION BY expr1, expr2, + the args are "expr1", "expr2". + order_by_args: The expr to order the partitions by. 
+ """ - def __init__( - self, + sql_function: SqlWindowFunction + sql_function_args: Sequence[SqlExpressionNode] + partition_by_args: Sequence[SqlExpressionNode] + order_by_args: Sequence[SqlWindowOrderByArgument] + + @staticmethod + def create( # noqa: D102 sql_function: SqlWindowFunction, sql_function_args: Sequence[SqlExpressionNode] = (), partition_by_args: Sequence[SqlExpressionNode] = (), order_by_args: Sequence[SqlWindowOrderByArgument] = (), - ) -> None: - """Constructor. - - Args: - sql_function: The function that this represents. - sql_function_args: The arguments that should go into the function. e.g. for "CONCAT(a, b)", the arg - expressions should be "a" and "b". - partition_by_args: The arguments to partition the rows. e.g. PARTITION BY expr1, expr2, - the args are "expr1", "expr2". - order_by_args: The expr to order the partitions by. - """ - self._sql_function = sql_function - self._sql_function_args = tuple(sql_function_args) - self._partition_by_args = tuple(partition_by_args) - self._order_by_args = order_by_args + ) -> SqlWindowFunctionExpression: parent_nodes: List[SqlExpressionNode] = [] if sql_function_args: parent_nodes.extend(sql_function_args) @@ -1031,7 +1016,13 @@ def __init__( parent_nodes.extend(partition_by_args) if order_by_args: parent_nodes.extend([x.expr for x in order_by_args]) - super().__init__(node_id=self.create_unique_id(), parent_nodes=parent_nodes) + return SqlWindowFunctionExpression( + parent_nodes=tuple(parent_nodes), + sql_function=sql_function, + sql_function_args=tuple(sql_function_args), + partition_by_args=tuple(partition_by_args), + order_by_args=tuple(order_by_args), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1046,7 +1037,7 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu @property def description(self) -> str: # noqa: D102 - return f"{self._sql_function.value} Window Function Expression" + return f"{self.sql_function.value} Window Function Expression" 
@property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 @@ -1058,27 +1049,13 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + tuple(DisplayedProperty("order_by_argument", x) for x in self.order_by_args) ) - @property - def sql_function(self) -> SqlWindowFunction: # noqa: D102 - return self._sql_function - - @property - def sql_function_args(self) -> Sequence[SqlExpressionNode]: # noqa: D102 - return self._sql_function_args - - @property - def partition_by_args(self) -> Sequence[SqlExpressionNode]: # noqa: D102 - return self._partition_by_args - - @property - def order_by_args(self) -> Sequence[SqlWindowOrderByArgument]: # noqa: D102 - return self._order_by_args - @property def is_aggregate_function(self) -> bool: # noqa: D102 return False - def __repr__(self) -> str: # noqa: D105 + @property + @override + def pretty_format(self) -> str: # noqa: D105 return f"{self.__class__.__name__}(node_id={self.node_id}, sql_function={self.sql_function.name})" def rewrite( # noqa: D102 @@ -1086,7 +1063,7 @@ def rewrite( # noqa: D102 column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlWindowFunctionExpression( + return SqlWindowFunctionExpression.create( sql_function=self.sql_function, sql_function_args=[ x.rewrite(column_replacements, should_render_table_alias) for x in self.sql_function_args @@ -1124,11 +1101,15 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 ) +@dataclass(frozen=True) class SqlNullExpression(SqlExpressionNode): """Represents NULL.""" - def __init__(self) -> None: # noqa: D107 - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + @staticmethod + def create() -> SqlNullExpression: # noqa: D102 + return SqlNullExpression( + parent_nodes=(), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1170,12 +1151,20 @@ class SqlLogicalOperator(Enum): OR = "OR" 
+@dataclass(frozen=True) class SqlLogicalExpression(SqlExpressionNode): """A logical expression like "a AND b AND c".""" - def __init__(self, operator: SqlLogicalOperator, args: Tuple[SqlExpressionNode, ...]) -> None: # noqa: D107 - self._operator = operator - super().__init__(node_id=self.create_unique_id(), parent_nodes=list(args)) + operator: SqlLogicalOperator + args: Tuple[SqlExpressionNode, ...] + + @staticmethod + def create(operator: SqlLogicalOperator, args: Sequence[SqlExpressionNode]) -> SqlLogicalExpression: # noqa: D102 + return SqlLogicalExpression( + parent_nodes=tuple(args), + operator=operator, + args=tuple(args), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1190,22 +1179,14 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu @property def description(self) -> str: # noqa: D102 - return f"Logical Operator {self._operator.value}" - - @property - def args(self) -> Sequence[SqlExpressionNode]: # noqa: D102 - return self.parent_nodes - - @property - def operator(self) -> SqlLogicalOperator: # noqa: D102 - return self._operator + return f"Logical Operator {self.operator.value}" def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=self.operator, args=tuple(x.rewrite(column_replacements, should_render_table_alias) for x in self.args), ) @@ -1222,12 +1203,18 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.operator == other.operator and self._parents_match(other) +@dataclass(frozen=True) class SqlIsNullExpression(SqlExpressionNode): """An IS NULL expression like "foo IS NULL".""" - def __init__(self, arg: SqlExpressionNode) -> None: # noqa: D107 - self._arg = arg - super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) + arg: SqlExpressionNode + + @staticmethod + def 
create(arg: SqlExpressionNode) -> SqlIsNullExpression: # noqa: D102 + return SqlIsNullExpression( + parent_nodes=(arg,), + arg=arg, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1244,16 +1231,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return "IS NULL Expression" - @property - def arg(self) -> SqlExpressionNode: # noqa: D102 - return self._arg - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlIsNullExpression(arg=self.arg.rewrite(column_replacements, should_render_table_alias)) + return SqlIsNullExpression.create(arg=self.arg.rewrite(column_replacements, should_render_table_alias)) @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 @@ -1265,6 +1248,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) +@dataclass(frozen=True) class SqlSubtractTimeIntervalExpression(SqlExpressionNode): """Represents an interval subtraction from a given timestamp. @@ -1274,16 +1258,22 @@ class SqlSubtractTimeIntervalExpression(SqlExpressionNode): value. 
""" - def __init__( # noqa: D107 - self, + arg: SqlExpressionNode + count: int + granularity: TimeGranularity + + @staticmethod + def create( # noqa: D102 arg: SqlExpressionNode, count: int, granularity: TimeGranularity, - ) -> None: - super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) - self._count = count - self._time_granularity = granularity - self._arg = arg + ) -> SqlSubtractTimeIntervalExpression: + return SqlSubtractTimeIntervalExpression( + parent_nodes=(arg,), + arg=arg, + count=count, + granularity=granularity, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1300,24 +1290,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return "Time delta" - @property - def arg(self) -> SqlExpressionNode: # noqa: D102 - return self._arg - - @property - def count(self) -> int: # noqa: D102 - return self._count - - @property - def granularity(self) -> TimeGranularity: # noqa: D102 - return self._time_granularity - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlSubtractTimeIntervalExpression( + return SqlSubtractTimeIntervalExpression.create( arg=self.arg.rewrite(column_replacements, should_render_table_alias), count=self.count, granularity=self.granularity, @@ -1335,11 +1313,18 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) +@dataclass(frozen=True) class SqlCastToTimestampExpression(SqlExpressionNode): """Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP).""" - def __init__(self, arg: SqlExpressionNode) -> None: # noqa: D107 - super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) + arg: SqlExpressionNode + + @staticmethod + def create(arg: SqlExpressionNode) -> 
SqlCastToTimestampExpression: # noqa: D102 + return SqlCastToTimestampExpression( + parent_nodes=(arg,), + arg=arg, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1356,17 +1341,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return "Cast to Timestamp" - @property - def arg(self) -> SqlExpressionNode: # noqa: D102 - assert len(self.parent_nodes) == 1 - return self.parent_nodes[0] - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlCastToTimestampExpression(arg=self.arg.rewrite(column_replacements, should_render_table_alias)) + return SqlCastToTimestampExpression.create(arg=self.arg.rewrite(column_replacements, should_render_table_alias)) @property def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 @@ -1380,18 +1360,20 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) +@dataclass(frozen=True) class SqlDateTruncExpression(SqlExpressionNode): """Apply a date trunc to a column like CAST('2020-01-01' AS TIMESTAMP).""" - def __init__(self, time_granularity: TimeGranularity, arg: SqlExpressionNode) -> None: - """Constructor. + time_granularity: TimeGranularity + arg: SqlExpressionNode - Args: - time_granularity: the granularity to DATE_TRUNC() to. - arg: the value to DATE_TRUNC(). 
- """ - self._time_granularity = time_granularity - super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) + @staticmethod + def create(time_granularity: TimeGranularity, arg: SqlExpressionNode) -> SqlDateTruncExpression: # noqa: D102 + return SqlDateTruncExpression( + parent_nodes=(arg,), + time_granularity=time_granularity, + arg=arg, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1408,21 +1390,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return f"DATE_TRUNC() to {self.time_granularity}" - @property - def time_granularity(self) -> TimeGranularity: # noqa: D102 - return self._time_granularity - - @property - def arg(self) -> SqlExpressionNode: # noqa: D102 - assert len(self.parent_nodes) == 1 - return self.parent_nodes[0] - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlDateTruncExpression( + return SqlDateTruncExpression.create( time_granularity=self.time_granularity, arg=self.arg.rewrite(column_replacements, should_render_table_alias) ) @@ -1438,18 +1411,28 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.time_granularity == other.time_granularity and self._parents_match(other) +@dataclass(frozen=True) class SqlExtractExpression(SqlExpressionNode): - """Extract a date part from a time expression.""" + """Extract a date part from a time expression. - def __init__(self, date_part: DatePart, arg: SqlExpressionNode) -> None: - """Constructor. + Attributes: + date_part: The date part to extract. + arg: The expression to extract from. + """ - Args: - date_part: the date part to extract. - arg: the expression to extract from. 
- """ - self._date_part = date_part - super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) + date_part: DatePart + arg: SqlExpressionNode + + @staticmethod + def create( # noqa: D102 + date_part: DatePart, + arg: SqlExpressionNode, + ) -> SqlExtractExpression: + return SqlExtractExpression( + parent_nodes=(arg,), + date_part=date_part, + arg=arg, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1466,21 +1449,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return f"Extract {self.date_part.name}" - @property - def date_part(self) -> DatePart: # noqa: D102 - return self._date_part - - @property - def arg(self) -> SqlExpressionNode: # noqa: D102 - assert len(self.parent_nodes) == 1 - return self.parent_nodes[0] - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlExtractExpression( + return SqlExtractExpression.create( date_part=self.date_part, arg=self.arg.rewrite(column_replacements, should_render_table_alias) ) @@ -1496,25 +1470,33 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.date_part == other.date_part and self._parents_match(other) +@dataclass(frozen=True) class SqlRatioComputationExpression(SqlExpressionNode): """Node for expressing Ratio metrics to allow for appropriate casting to float/double in each engine. - In future we might wish to break this up into a set of nodes, e.g., SqlCastExpression and SqlMathExpression + In the future, we might wish to break this up into a set of nodes, e.g., SqlCastExpression and SqlMathExpression or even add CAST to SqlFunctionExpression. However, at this time the only mathematical operation we encode is division, and we only use that for ratios. 
Similarly, the only times we do typecasting are when we are coercing timestamps (already handled) or computing ratio metrics. + + Attributes: + numerator: The expression for the numerator in the ratio. + denominator: The expression for the denominator in the ratio. """ - def __init__(self, numerator: SqlExpressionNode, denominator: SqlExpressionNode) -> None: - """Initialize this node for computing a ratio. Expression renderers should handle the casting. + numerator: SqlExpressionNode + denominator: SqlExpressionNode - Args: - numerator: the expression for the numerator in the ratio - denominator: the expression for the denominator in the ratio - """ - self._numerator = numerator - self._denominator = denominator - super().__init__(node_id=self.create_unique_id(), parent_nodes=[numerator, denominator]) + @staticmethod + def create( # noqa: D102 + numerator: SqlExpressionNode, + denominator: SqlExpressionNode, + ) -> SqlRatioComputationExpression: + return SqlRatioComputationExpression( + parent_nodes=(numerator, denominator), + numerator=numerator, + denominator=denominator, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1531,20 +1513,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return "Divide numerator by denominator, with appropriate casting" - @property - def numerator(self) -> SqlExpressionNode: # noqa: D102 - return self._numerator - - @property - def denominator(self) -> SqlExpressionNode: # noqa: D102 - return self._denominator - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlRatioComputationExpression( + return SqlRatioComputationExpression.create( numerator=self.numerator.rewrite(column_replacements, should_render_table_alias), denominator=self.denominator.rewrite(column_replacements, should_render_table_alias), ) @@ 
-1561,16 +1535,32 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) +@dataclass(frozen=True) class SqlBetweenExpression(SqlExpressionNode): - """A BETWEEN clause like `column BETWEEN val1 AND val2`.""" + """A BETWEEN clause like `column BETWEEN val1 AND val2`. + + Attributes: + column_arg: The column or expression to apply the BETWEEN clause. + start_expr: The start expression of the BETWEEN clause. + end_expr: The end expression of the BETWEEN clause. + """ - def __init__( # noqa: D107 - self, column_arg: SqlExpressionNode, start_expr: SqlExpressionNode, end_expr: SqlExpressionNode - ) -> None: - self._column_arg = column_arg - self._start_expr = start_expr - self._end_expr = end_expr - super().__init__(node_id=self.create_unique_id(), parent_nodes=[column_arg, start_expr, end_expr]) + column_arg: SqlExpressionNode + start_expr: SqlExpressionNode + end_expr: SqlExpressionNode + + @staticmethod + def create( # noqa: D102 + column_arg: SqlExpressionNode, + start_expr: SqlExpressionNode, + end_expr: SqlExpressionNode, + ) -> SqlBetweenExpression: + return SqlBetweenExpression( + parent_nodes=(column_arg, start_expr, end_expr), + column_arg=column_arg, + start_expr=start_expr, + end_expr=end_expr, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -1587,24 +1577,12 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu def description(self) -> str: # noqa: D102 return "BETWEEN operator" - @property - def column_arg(self) -> SqlExpressionNode: # noqa: D102 - return self._column_arg - - @property - def start_expr(self) -> SqlExpressionNode: # noqa: D102 - return self._start_expr - - @property - def end_expr(self) -> SqlExpressionNode: # noqa: D102 - return self._end_expr - def rewrite( # noqa: D102 self, column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlBetweenExpression( + return 
SqlBetweenExpression.create( column_arg=self.column_arg.rewrite(column_replacements, should_render_table_alias), start_expr=self.start_expr.rewrite(column_replacements, should_render_table_alias), end_expr=self.end_expr.rewrite(column_replacements, should_render_table_alias), @@ -1622,11 +1600,15 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) +@dataclass(frozen=True) class SqlGenerateUuidExpression(SqlExpressionNode): - """Renders a sql to generate a random uuid, is non-deterministic..""" + """Renders a SQL to generate a random UUID, which is non-deterministic.""" - def __init__(self) -> None: # noqa: D107 - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + @staticmethod + def create() -> SqlGenerateUuidExpression: # noqa: D102 + return SqlGenerateUuidExpression( + parent_nodes=(), + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index af09f023ea..5c07a33fac 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -5,10 +5,10 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, List, Optional, Sequence, Tuple +from typing import Generic, Optional, Sequence, Tuple from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix -from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId +from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.visitor import VisitorOutputT from typing_extensions import override @@ -19,7 +19,8 @@ logger = logging.getLogger(__name__) -class SqlQueryPlanNode(DagNode, ABC): +@dataclass(frozen=True) +class SqlQueryPlanNode(DagNode["SqlQueryPlanNode"], ABC): """Modeling a SQL query plan like a data flow plan as well. 
In that model: @@ -33,14 +34,6 @@ class SqlQueryPlanNode(DagNode, ABC): Is there an existing library that can do this? """ - def __init__(self, node_id: NodeId, parent_nodes: Sequence[SqlQueryPlanNode]) -> None: # noqa: D107 - self._parent_nodes = parent_nodes - super().__init__(node_id=node_id) - - @property - def parent_nodes(self) -> List[SqlQueryPlanNode]: # noqa: D102 - return list(self._parent_nodes) - @abstractmethod def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: """Called when a visitor needs to visit this node.""" @@ -108,92 +101,78 @@ class SqlOrderByDescription: # noqa: D101 desc: bool +@dataclass(frozen=True) class SqlSelectStatementNode(SqlQueryPlanNode): - """Represents an SQL Select statement.""" + """Represents an SQL Select statement. + + Attributes: + select_columns: The columns to select. + from_source: The source of the data for the select statement. + from_source_alias: Alias for the from source. + join_descs: Descriptions of the joins to perform. + group_bys: The columns to group by. + order_bys: The columns to order by. + where: The where clause expression. + limit: The limit of the number of rows to return. + distinct: Whether the select statement should return distinct rows. + """ - def __init__( # noqa: D107 - self, + _description: str + select_columns: Tuple[SqlSelectColumn, ...] + from_source: SqlQueryPlanNode + from_source_alias: str + join_descs: Tuple[SqlJoinDescription, ...] + group_bys: Tuple[SqlSelectColumn, ...] + order_bys: Tuple[SqlOrderByDescription, ...] + where: Optional[SqlExpressionNode] + limit: Optional[int] + distinct: bool + + @staticmethod + def create( # noqa: D102 description: str, select_columns: Tuple[SqlSelectColumn, ...], from_source: SqlQueryPlanNode, from_source_alias: str, - joins_descs: Tuple[SqlJoinDescription, ...] = (), + join_descs: Tuple[SqlJoinDescription, ...] = (), group_bys: Tuple[SqlSelectColumn, ...] = (), order_bys: Tuple[SqlOrderByDescription, ...] 
= (), where: Optional[SqlExpressionNode] = None, limit: Optional[int] = None, distinct: bool = False, - ) -> None: - self._description = description - assert select_columns - self._select_columns = select_columns - # Sources that belong in a from clause. CTEs could be captured in a separate field. - self._from_source = from_source - self._from_source_alias = from_source_alias - self._join_descs = joins_descs - self._group_bys = group_bys - self._where = where - self._order_bys = order_bys - self._distinct = distinct - - if limit is not None: - assert limit >= 0 - self._limit = limit - - super().__init__( - node_id=self.create_unique_id(), - parent_nodes=[self._from_source] + [x.right_source for x in self._join_descs], + ) -> SqlSelectStatementNode: + parent_nodes = [from_source] + [x.right_source for x in join_descs] + return SqlSelectStatementNode( + parent_nodes=tuple(parent_nodes), + _description=description, + select_columns=select_columns, + from_source=from_source, + from_source_alias=from_source_alias, + join_descs=join_descs, + group_bys=group_bys, + order_bys=order_bys, + where=where, + limit=limit, + distinct=distinct, ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.SQL_PLAN_SELECT_STATEMENT_ID_PREFIX - @property - def description(self) -> str: # noqa: D102 - return self._description - @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return ( tuple(super().displayed_properties) - + tuple(DisplayedProperty(f"col{i}", column) for i, column in enumerate(self._select_columns)) + + tuple(DisplayedProperty(f"col{i}", column) for i, column in enumerate(self.select_columns)) + (DisplayedProperty("from_source", self.from_source),) - + tuple(DisplayedProperty(f"join_{i}", join_desc) for i, join_desc in enumerate(self._join_descs)) - + tuple(DisplayedProperty(f"group_by{i}", group_by) for i, group_by in enumerate(self._group_bys)) - + (DisplayedProperty("where", self._where),) - + 
tuple(DisplayedProperty(f"order_by{i}", order_by) for i, order_by in enumerate(self._order_bys)) - + (DisplayedProperty("distinct", self._distinct),) + + tuple(DisplayedProperty(f"join_{i}", join_desc) for i, join_desc in enumerate(self.join_descs)) + + tuple(DisplayedProperty(f"group_by{i}", group_by) for i, group_by in enumerate(self.group_bys)) + + (DisplayedProperty("where", self.where),) + + tuple(DisplayedProperty(f"order_by{i}", order_by) for i, order_by in enumerate(self.order_bys)) + + (DisplayedProperty("distinct", self.distinct),) ) - @property - def select_columns(self) -> Tuple[SqlSelectColumn, ...]: # noqa: D102 - return self._select_columns - - @property - def from_source(self) -> SqlQueryPlanNode: # noqa: D102 - return self._from_source - - @property - def from_source_alias(self) -> str: # noqa: D102 - return self._from_source_alias - - @property - def join_descs(self) -> Tuple[SqlJoinDescription, ...]: # noqa: D102 - return self._join_descs - - @property - def group_bys(self) -> Tuple[SqlSelectColumn, ...]: # noqa: D102 - return self._group_bys - - @property - def where(self) -> Optional[SqlExpressionNode]: # noqa: D102 - return self._where - - @property - def order_bys(self) -> Tuple[SqlOrderByDescription, ...]: # noqa: D102 - return self._order_bys - def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_select_statement_node(self) @@ -201,25 +180,28 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut def is_table(self) -> bool: # noqa: D102 return False - @property - def limit(self) -> Optional[int]: # noqa: D102 - return self._limit - @property def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return self @property - def distinct(self) -> bool: # noqa: D102 - return self._distinct + @override + def description(self) -> str: + return self._description +@dataclass(frozen=True) class SqlTableFromClauseNode(SqlQueryPlanNode): 
"""An SQL table that can go in the FROM clause.""" - def __init__(self, sql_table: SqlTable) -> None: # noqa: D107 - self._sql_table = sql_table - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + sql_table: SqlTable + + @staticmethod + def create(sql_table: SqlTable) -> SqlTableFromClauseNode: # noqa: D102 + return SqlTableFromClauseNode( + parent_nodes=(), + sql_table=sql_table, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -227,19 +209,15 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 @property def description(self) -> str: # noqa: D102 - return f"Read from {self._sql_table.sql}" + return f"Read from {self.sql_table.sql}" @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - return tuple(super().displayed_properties) + (DisplayedProperty("table_id", self._sql_table.sql),) + return tuple(super().displayed_properties) + (DisplayedProperty("table_id", self.sql_table.sql),) def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_table_from_clause_node(self) - @property - def sql_table(self) -> SqlTable: # noqa: D102 - return self._sql_table - @property def is_table(self) -> bool: # noqa: D102 return True @@ -249,12 +227,22 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None +@dataclass(frozen=True) class SqlSelectQueryFromClauseNode(SqlQueryPlanNode): - """An SQL select query that can go in the FROM clause.""" + """An SQL select query that can go in the FROM clause. + + Attributes: + select_query: The SQL select query to include in the FROM clause. 
+ """ - def __init__(self, select_query: str) -> None: # noqa: D107 - self._select_query = select_query - super().__init__(node_id=self.create_unique_id(), parent_nodes=[]) + select_query: str + + @staticmethod + def create(select_query: str) -> SqlSelectQueryFromClauseNode: # noqa: D102 + return SqlSelectQueryFromClauseNode( + parent_nodes=(), + select_query=select_query, + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -267,10 +255,6 @@ def description(self) -> str: # noqa: D102 def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_query_from_clause_node(self) - @property - def select_query(self) -> str: # noqa: D102 - return self._select_query - @property def is_table(self) -> bool: # noqa: D102 return False @@ -280,13 +264,27 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None +@dataclass(frozen=True) class SqlCreateTableAsNode(SqlQueryPlanNode): - """An SQL select query that can go in the FROM clause.""" + """An SQL node representing a CREATE TABLE AS statement. + + Attributes: + sql_table: The SQL table to create. + parent_node: The parent query plan node. 
+ """ + + sql_table: SqlTable - def __init__(self, sql_table: SqlTable, parent_node: SqlQueryPlanNode) -> None: # noqa: D107 - self._sql_table = sql_table - self._parent_node = parent_node - super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,)) + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create(sql_table: SqlTable, parent_node: SqlQueryPlanNode) -> SqlCreateTableAsNode: # noqa: D102 + return SqlCreateTableAsNode( + parent_nodes=(parent_node,), + sql_table=sql_table, + ) @override def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: @@ -307,20 +305,15 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: def description(self) -> str: return f"Create table {repr(self.sql_table.sql)}" + @property + def parent_node(self) -> SqlQueryPlanNode: # noqa: D102 + return self.parent_nodes[0] + @classmethod @override def id_prefix(cls) -> IdPrefix: return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX - @property - def sql_table(self) -> SqlTable: - """Return the table that this statement would create.""" - return self._sql_table - - @property - def parent_node(self) -> SqlQueryPlanNode: # noqa: D102 - return self._parent_node - class SqlQueryPlan(MetricFlowDag[SqlQueryPlanNode]): """Model for an SQL Query as a DAG.""" diff --git a/metricflow/validation/data_warehouse_model_validator.py b/metricflow/validation/data_warehouse_model_validator.py index 8ffd7a2a0a..a9d076278f 100644 --- a/metricflow/validation/data_warehouse_model_validator.py +++ b/metricflow/validation/data_warehouse_model_validator.py @@ -201,7 +201,7 @@ def gen_dimension_tasks( spec_filter_tuples.append( ( spec, - FilterElementsNode( + FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(dimension_specs=(spec,)) ), ) @@ -214,7 +214,7 @@ def gen_dimension_tasks( spec_filter_tuples.append( ( spec, - FilterElementsNode( 
+ FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(time_dimension_specs=(spec,)) ), ) @@ -241,7 +241,7 @@ def gen_dimension_tasks( ) ) - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet( dimension_specs=dimension_specs, @@ -299,7 +299,7 @@ def gen_entity_tasks( dataset.instance_set.spec_set.entity_specs ) for spec in semantic_model_specs: - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(entity_specs=(spec,)) ) semantic_model_sub_tasks.append( @@ -322,7 +322,7 @@ def gen_entity_tasks( ) ) - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet( entity_specs=tuple(semantic_model_specs), @@ -392,7 +392,7 @@ def gen_measure_tasks( obtained_source_node = source_node_by_measure_spec.get(spec) assert obtained_source_node, f"Unable to find generated source node for measure: {spec.element_name}" - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=obtained_source_node, include_specs=InstanceSpecSet( measure_specs=(spec,), @@ -419,7 +419,7 @@ def gen_measure_tasks( ) for measure_specs, source_node in measure_specs_source_node_pair: - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=measure_specs) ) tasks.append( diff --git a/scripts/ci_tests/metricflow_package_test.py b/scripts/ci_tests/metricflow_package_test.py index 5ca4ef645f..ab8bfb0f49 100644 --- a/scripts/ci_tests/metricflow_package_test.py +++ b/scripts/ci_tests/metricflow_package_test.py @@ -36,7 +36,7 @@ def _data_set_to_read_nodes(data_sets: OrderedDict[str, SemanticModelDataSet]) - # Moved from model_fixtures.py. 
return_dict: OrderedDict[str, ReadSqlSourceNode] = OrderedDict() for semantic_model_name, data_set in data_sets.items(): - return_dict[semantic_model_name] = ReadSqlSourceNode(data_set) + return_dict[semantic_model_name] = ReadSqlSourceNode.create(data_set) return return_dict diff --git a/tests_metricflow/dataflow/builder/test_node_data_set.py b/tests_metricflow/dataflow/builder/test_node_data_set.py index ee668769f8..8c05459111 100644 --- a/tests_metricflow/dataflow/builder/test_node_data_set.py +++ b/tests_metricflow/dataflow/builder/test_node_data_set.py @@ -68,20 +68,24 @@ def test_no_parent_node_data_set( time_dimension_instances=(), entity_instances=(), ), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference(table_alias="src", column_name="bookings")), + expr=SqlColumnReferenceExpression.create( + SqlColumnReference(table_alias="src", column_name="bookings") + ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="src", ), ) - node = ReadSqlSourceNode(data_set=data_set) + node = ReadSqlSourceNode.create(data_set=data_set) assert resolver.get_output_data_set(node).instance_set == data_set.instance_set @@ -102,7 +106,7 @@ def test_joined_node_data_set( # Join "revenue" with "users_latest" to get "user__home_state_latest" revenue_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["revenue"] users_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["users_latest"] - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=revenue_node, join_targets=[ JoinDescription( diff --git 
a/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py b/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py index eaceb6465f..096dd30fa9 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py @@ -27,7 +27,7 @@ def make_dataflow_plan(node: DataflowPlanNode) -> DataflowPlan: # noqa: D103 return DataflowPlan( - sink_nodes=[WriteToResultDataTableNode(node)], + sink_nodes=[WriteToResultDataTableNode.create(node)], plan_id=DagId.from_id_prefix(StaticIdPrefix.OPTIMIZED_DATAFLOW_PLAN_PREFIX), ) @@ -69,11 +69,11 @@ def test_filter_combination( ) -> None: """Tests combining a single node.""" source0 = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["bookings_source"] - filter0 = FilterElementsNode( + filter0 = FilterElementsNode.create( parent_node=source0, include_specs=InstanceSpecSet(measure_specs=(MeasureSpec(element_name="bookings"),)) ) source1 = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["bookings_source"] - filter1 = FilterElementsNode( + filter1 = FilterElementsNode.create( parent_node=source1, include_specs=InstanceSpecSet( measure_specs=(MeasureSpec(element_name="booking_value"),), diff --git a/tests_metricflow/examples/test_node_sql.py b/tests_metricflow/examples/test_node_sql.py index d3c96fffc6..aa88d93868 100644 --- a/tests_metricflow/examples/test_node_sql.py +++ b/tests_metricflow/examples/test_node_sql.py @@ -51,7 +51,7 @@ def test_view_sql_generated_at_a_node( # Show SQL and spec set at a source node. 
bookings_source_data_set = to_data_set_converter.create_sql_source_data_set(bookings_semantic_model) - read_source_node = ReadSqlSourceNode(bookings_source_data_set) + read_source_node = ReadSqlSourceNode.create(bookings_source_data_set) conversion_result = to_sql_plan_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, dataflow_plan_node=read_source_node, @@ -63,13 +63,13 @@ def test_view_sql_generated_at_a_node( logger.info(f"SQL generated at {read_source_node} is:\n\n{sql_at_read_node}") logger.info(f"Spec set at {read_source_node} is:\n\n{mf_pformat(spec_set_at_read_node)}") - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=read_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) # Show SQL and spec set at a filter node. - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( time_dimension_specs=( diff --git a/tests_metricflow/execution/noop_task.py b/tests_metricflow/execution/noop_task.py index 563d16feaa..8f65ec049a 100644 --- a/tests_metricflow/execution/noop_task.py +++ b/tests_metricflow/execution/noop_task.py @@ -2,13 +2,13 @@ import logging import time -from typing import Optional, Sequence +from dataclasses import dataclass +from typing import ClassVar, Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow.execution.execution_plan import ( ExecutionPlanTask, - SqlQuery, TaskExecutionError, TaskExecutionResult, ) @@ -16,21 +16,28 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) class NoOpExecutionPlanTask(ExecutionPlanTask): - """A no-op task for testing executors.""" + """A no-op task for testing executors. - # Error to return if should_error is set. 
- EXAMPLE_ERROR = TaskExecutionError("Expected Error") + Attributes: + should_error: If true, test the error flow by intentionally returning an error in the results. + """ - def __init__(self, parent_tasks: Sequence[ExecutionPlanTask] = (), should_error: bool = False) -> None: - """Constructor. + EXAMPLE_ERROR: ClassVar[TaskExecutionError] = TaskExecutionError("Expected Error") - Args: - parent_tasks: Self-explanatory. - should_error: if true, return an error in the results. - """ - self._should_error = should_error - super().__init__(task_id=self.create_unique_id(), parent_nodes=list(parent_tasks)) + should_error: bool = False + + @staticmethod + def create( # noqa: D102 + parent_tasks: Sequence[ExecutionPlanTask] = (), + should_error: bool = False, + ) -> NoOpExecutionPlanTask: + return NoOpExecutionPlanTask( + parent_nodes=tuple(parent_tasks), + sql_query=None, + should_error=should_error, + ) @property def description(self) -> str: # noqa: D102 @@ -45,9 +52,5 @@ def execute(self) -> TaskExecutionResult: # noqa: D102 time.sleep(0.01) end_time = time.time() return TaskExecutionResult( - start_time=start_time, end_time=end_time, errors=(self.EXAMPLE_ERROR,) if self._should_error else () + start_time=start_time, end_time=end_time, errors=(self.EXAMPLE_ERROR,) if self.should_error else () ) - - @property - def sql_query(self) -> Optional[SqlQuery]: # noqa: D102 - return None diff --git a/tests_metricflow/execution/test_sequential_executor.py b/tests_metricflow/execution/test_sequential_executor.py index 84f29b6461..998f842da8 100644 --- a/tests_metricflow/execution/test_sequential_executor.py +++ b/tests_metricflow/execution/test_sequential_executor.py @@ -9,7 +9,7 @@ def test_single_task() -> None: """Tests running an execution plan with a single task.""" - task = NoOpExecutionPlanTask() + task = NoOpExecutionPlanTask.create() execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) results = 
SequentialPlanExecutor().execute_plan(execution_plan) assert results.get_result(task.task_id) @@ -17,7 +17,7 @@ def test_single_task() -> None: def test_single_task_error() -> None: """Check that an error is properly returned in the results if a task errors out.""" - task = NoOpExecutionPlanTask(should_error=True) + task = NoOpExecutionPlanTask.create(should_error=True) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) executor = SequentialPlanExecutor() results = executor.execute_plan(execution_plan) @@ -27,9 +27,9 @@ def test_single_task_error() -> None: def test_task_with_parents() -> None: """Tests a plan with a task that has 2 direct parents.""" - parent_task1 = NoOpExecutionPlanTask() - parent_task2 = NoOpExecutionPlanTask() - leaf_task = NoOpExecutionPlanTask(parent_tasks=[parent_task1, parent_task2]) + parent_task1 = NoOpExecutionPlanTask.create() + parent_task2 = NoOpExecutionPlanTask.create() + leaf_task = NoOpExecutionPlanTask.create(parent_tasks=[parent_task1, parent_task2]) execution_plan = ExecutionPlan(leaf_tasks=[leaf_task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) @@ -47,9 +47,9 @@ def test_task_with_parents() -> None: def test_parent_task_error() -> None: """Check that a child task is not run if a parent task fails.""" - parent_task1 = NoOpExecutionPlanTask(should_error=True) - parent_task2 = NoOpExecutionPlanTask() - leaf_task = NoOpExecutionPlanTask(parent_tasks=[parent_task1, parent_task2]) + parent_task1 = NoOpExecutionPlanTask.create(should_error=True) + parent_task2 = NoOpExecutionPlanTask.create() + leaf_task = NoOpExecutionPlanTask.create(parent_tasks=[parent_task1, parent_task2]) execution_plan = ExecutionPlan(leaf_tasks=[leaf_task], dag_id=DagId.from_str("plan0")) executor = SequentialPlanExecutor() diff --git a/tests_metricflow/execution/test_tasks.py b/tests_metricflow/execution/test_tasks.py index 17fe99f2ef..6e24ec85de 100644 --- 
a/tests_metricflow/execution/test_tasks.py +++ b/tests_metricflow/execution/test_tasks.py @@ -10,6 +10,7 @@ ExecutionPlan, SelectSqlQueryToDataTableTask, SelectSqlQueryToTableTask, + SqlQuery, ) from metricflow.execution.executor import SequentialPlanExecutor from metricflow.protocols.sql_client import SqlClient, SqlEngine @@ -18,7 +19,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D103 - task = SelectSqlQueryToDataTableTask(sql_client, "SELECT 1 AS foo", SqlBindParameters()) + task = SelectSqlQueryToDataTableTask.create(sql_client, SqlQuery("SELECT 1 AS foo", SqlBindParameters())) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) @@ -41,10 +42,12 @@ def test_write_table_task( # noqa: D103 mf_test_configuration: MetricFlowTestConfiguration, sql_client: SqlClient ) -> None: # noqa: D103 output_table = SqlTable(schema_name=mf_test_configuration.mf_system_schema, table_name=f"test_table_{random_id()}") - task = SelectSqlQueryToTableTask( + task = SelectSqlQueryToTableTask.create( sql_client=sql_client, - sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", - bind_parameters=SqlBindParameters(), + sql_query=SqlQuery( + sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", + bind_parameters=SqlBindParameters(), + ), output_table=output_table, ) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) diff --git a/tests_metricflow/fixtures/manifest_fixtures.py b/tests_metricflow/fixtures/manifest_fixtures.py index 6ab1061a42..f34be3fc83 100644 --- a/tests_metricflow/fixtures/manifest_fixtures.py +++ b/tests_metricflow/fixtures/manifest_fixtures.py @@ -221,7 +221,7 @@ def _data_set_to_read_nodes( # Moved from model_fixtures.py. 
return_dict: OrderedDict[str, ReadSqlSourceNode] = OrderedDict() for semantic_model_name, data_set in data_sets.items(): - return_dict[semantic_model_name] = ReadSqlSourceNode(data_set) + return_dict[semantic_model_name] = ReadSqlSourceNode.create(data_set) logger.debug( f"For semantic model {semantic_model_name}, creating node_id {return_dict[semantic_model_name].node_id}" ) diff --git a/tests_metricflow/integration/test_configured_cases.py b/tests_metricflow/integration/test_configured_cases.py index f381fde20b..4f9c563d6b 100644 --- a/tests_metricflow/integration/test_configured_cases.py +++ b/tests_metricflow/integration/test_configured_cases.py @@ -97,8 +97,8 @@ def render_date_sub( granularity: TimeGranularity, ) -> str: """Renders a date subtract expression.""" - expr = SqlSubtractTimeIntervalExpression( - arg=SqlColumnReferenceExpression(SqlColumnReference(table_alias, column_alias)), + expr = SqlSubtractTimeIntervalExpression.create( + arg=SqlColumnReferenceExpression.create(SqlColumnReference(table_alias, column_alias)), count=count, granularity=granularity, ) @@ -106,10 +106,10 @@ def render_date_sub( def render_date_trunc(self, expr: str, granularity: TimeGranularity) -> str: """Return the DATE_TRUNC() call that can be used for converting the given expr to the granularity.""" - renderable_expr = SqlDateTruncExpression( + renderable_expr = SqlDateTruncExpression.create( time_granularity=granularity, - arg=SqlCastToTimestampExpression( - arg=SqlStringExpression( + arg=SqlCastToTimestampExpression.create( + arg=SqlStringExpression.create( sql_expr=expr, requires_parenthesis=False, ) @@ -119,10 +119,10 @@ def render_date_trunc(self, expr: str, granularity: TimeGranularity) -> str: def render_extract(self, expr: str, date_part: DatePart) -> str: """Return the EXTRACT call that can be used for converting the given expr to the date_part.""" - renderable_expr = SqlExtractExpression( + renderable_expr = SqlExtractExpression.create( date_part=date_part, - 
arg=SqlCastToTimestampExpression( - arg=SqlStringExpression( + arg=SqlCastToTimestampExpression.create( + arg=SqlStringExpression.create( sql_expr=expr, requires_parenthesis=False, ) @@ -142,8 +142,8 @@ def render_percentile_expr( ) ) - renderable_expr = SqlPercentileExpression( - order_by_arg=SqlStringExpression( + renderable_expr = SqlPercentileExpression.create( + order_by_arg=SqlStringExpression.create( sql_expr=expr, requires_parenthesis=False, ), @@ -191,7 +191,7 @@ def render_time_dimension_template( def generate_random_uuid(self) -> str: """Returns the generate random UUID SQL function.""" - expr = SqlGenerateUuidExpression() + expr = SqlGenerateUuidExpression.create() return self._sql_client.sql_query_plan_renderer.expr_renderer.render_sql_expr(expr).sql diff --git a/tests_metricflow/mf_logging/test_dag_to_text.py b/tests_metricflow/mf_logging/test_dag_to_text.py index 06086744c5..079d31bd90 100644 --- a/tests_metricflow/mf_logging/test_dag_to_text.py +++ b/tests_metricflow/mf_logging/test_dag_to_text.py @@ -33,15 +33,15 @@ def test_multithread_dag_to_text() -> None: dag_to_text_formatter = MetricFlowDagTextFormatter(max_width=1) dag = SqlQueryPlan( plan_id=DagId("plan"), - render_node=SqlSelectStatementNode( + render_node=SqlSelectStatementNode.create( description="test", select_columns=( SqlSelectColumn( - expr=SqlStringExpression("'foo'"), + expr=SqlStringExpression.create("'foo'"), column_alias="bar", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="schema", table_name="table")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="schema", table_name="table")), from_source_alias="src", ), ) diff --git a/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py b/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py index 15e8da4a9f..3a4b460772 100644 --- a/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py +++ 
b/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py @@ -30,7 +30,7 @@ def test_metric_time_dimension_transform_node_using_primary_time( source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_dimension_transform_node = MetricTimeDimensionTransformNode( + metric_time_dimension_transform_node = MetricTimeDimensionTransformNode.create( parent_node=source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds") ) convert_and_check( @@ -54,7 +54,7 @@ def test_metric_time_dimension_transform_node_using_non_primary_time( source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_dimension_transform_node = MetricTimeDimensionTransformNode( + metric_time_dimension_transform_node = MetricTimeDimensionTransformNode.create( parent_node=source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="paid_at"), ) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py index 15ba3620d5..88fd000fda 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py @@ -153,7 +153,7 @@ def test_filter_node( source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filter_node = FilterElementsNode( + filter_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)) ) @@ -184,11 +184,11 @@ def test_filter_with_where_constraint_node( ] ds_spec = TimeDimensionSpec(element_name="ds", entity_links=(), time_granularity=TimeGranularity.DAY) - filter_node = FilterElementsNode( + filter_node = FilterElementsNode.create( parent_node=source_node, 
include_specs=InstanceSpecSet(measure_specs=(measure_spec,), time_dimension_specs=(ds_spec,)), ) # need to include ds_spec because where constraint operates on ds - where_constraint_node = WhereConstraintNode( + where_constraint_node = WhereConstraintNode.create( parent_node=filter_node, where_specs=( WhereFilterSpec( @@ -258,12 +258,12 @@ def test_measure_aggregation_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs)), ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) @@ -292,7 +292,7 @@ def test_single_join_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -307,7 +307,7 @@ def test_single_join_node( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -315,7 +315,7 @@ def test_single_join_node( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -353,7 +353,7 @@ def test_multi_join_node( measure_source_node = 
mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), entity_specs=(entity_spec,)), ) @@ -365,7 +365,7 @@ def test_multi_join_node( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -373,7 +373,7 @@ def test_multi_join_node( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -419,7 +419,7 @@ def test_compute_metrics_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -434,7 +434,7 @@ def test_compute_metrics_node( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -442,7 +442,7 @@ def test_compute_metrics_node( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -455,12 +455,12 @@ def test_compute_metrics_node( ], ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = 
AggregateMeasuresNode.create( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measure_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, dimension_spec}, @@ -492,7 +492,7 @@ def test_compute_metrics_node_simple_expr( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), entity_specs=(entity_spec,)), ) @@ -504,7 +504,7 @@ def test_compute_metrics_node_simple_expr( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -512,7 +512,7 @@ def test_compute_metrics_node_simple_expr( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -525,17 +525,17 @@ def test_compute_metrics_node_simple_expr( ], ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, dimension_spec}, ) - sink_node = WriteToResultDataTableNode(compute_metrics_node) + 
sink_node = WriteToResultDataTableNode.create(compute_metrics_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -578,27 +578,27 @@ def test_join_to_time_spine_node_without_offset( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec}, ) - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, requested_agg_time_dimension_specs=[MTD_SPEC_DAY], use_custom_agg_time_dimension=False, @@ -608,7 +608,7 @@ def test_join_to_time_spine_node_without_offset( join_type=SqlJoinType.INNER, ) - sink_node = WriteToResultDataTableNode(join_to_time_spine_node) + sink_node = WriteToResultDataTableNode.create(join_to_time_spine_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -651,26 +651,26 @@ def 
test_join_to_time_spine_node_with_offset_window( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, metric_time_spec}, ) - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, requested_agg_time_dimension_specs=[MTD_SPEC_DAY], use_custom_agg_time_dimension=False, @@ -681,7 +681,7 @@ def test_join_to_time_spine_node_with_offset_window( join_type=SqlJoinType.INNER, ) - sink_node = WriteToResultDataTableNode(join_to_time_spine_node) + sink_node = WriteToResultDataTableNode.create(join_to_time_spine_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -724,26 +724,26 @@ def test_join_to_time_spine_node_with_offset_to_grain( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_node = 
MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, metric_time_spec}, ) - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, requested_agg_time_dimension_specs=[MTD_SPEC_DAY], use_custom_agg_time_dimension=False, @@ -755,7 +755,7 @@ def test_join_to_time_spine_node_with_offset_to_grain( join_type=SqlJoinType.INNER, ) - sink_node = WriteToResultDataTableNode(join_to_time_spine_node) + sink_node = WriteToResultDataTableNode.create(join_to_time_spine_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -803,7 +803,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measures_node = FilterElementsNode( + filtered_measures_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(numerator_spec, denominator_spec), entity_specs=(entity_spec,)), 
) @@ -815,7 +815,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -823,7 +823,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measures_node, join_targets=[ JoinDescription( @@ -836,11 +836,11 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( ], ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings_per_booker") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, dimension_spec}, @@ -882,7 +882,7 @@ def test_order_by_node( "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -891,18 +891,18 @@ def test_order_by_node( ), ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measure_node, metric_specs=[metric_spec], aggregated_to_elements={dimension_spec, time_dimension_spec}, ) - 
order_by_node = OrderByLimitNode( + order_by_node = OrderByLimitNode.create( order_by_specs=[ OrderBySpec( instance_spec=time_dimension_spec, @@ -940,7 +940,7 @@ def test_semi_additive_join_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "accounts_source" ] - semi_additive_join_node = SemiAdditiveJoinNode( + semi_additive_join_node = SemiAdditiveJoinNode.create( parent_node=measure_source_node, entity_specs=tuple(), time_dimension_spec=time_dimension_spec, @@ -974,7 +974,7 @@ def test_semi_additive_join_node_with_queried_group_by( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "accounts_source" ] - semi_additive_join_node = SemiAdditiveJoinNode( + semi_additive_join_node = SemiAdditiveJoinNode.create( parent_node=measure_source_node, entity_specs=tuple(), time_dimension_spec=time_dimension_spec, @@ -1010,7 +1010,7 @@ def test_semi_additive_join_node_with_grouping( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "accounts_source" ] - semi_additive_join_node = SemiAdditiveJoinNode( + semi_additive_join_node = SemiAdditiveJoinNode.create( parent_node=measure_source_node, entity_specs=(entity_spec,), time_dimension_spec=time_dimension_spec, @@ -1037,7 +1037,7 @@ def test_constrain_time_range_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=( @@ -1050,12 +1050,12 @@ def test_constrain_time_range_node( ), ), ) - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=filtered_measure_node, 
aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - constrain_time_node = ConstrainTimeRangeNode( + constrain_time_node = ConstrainTimeRangeNode.create( parent_node=metric_time_node, time_range_constraint=TimeRangeConstraint( start_time=as_datetime("2020-01-01"), @@ -1136,29 +1136,29 @@ def test_combine_output_node( # Build compute measures node measure_specs: List[MeasureSpec] = [sum_spec] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs), dimension_specs=(dimension_spec,)), ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=tuple(MetricInputMeasureSpec(measure_spec=x) for x in measure_specs), ) # Build agg measures node measure_specs_2 = [sum_boolean_spec, count_distinct_spec] - filtered_measure_node_2 = FilterElementsNode( + filtered_measure_node_2 = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs_2), dimension_specs=(dimension_spec,)), ) - aggregated_measure_node_2 = AggregateMeasuresNode( + aggregated_measure_node_2 = AggregateMeasuresNode.create( parent_node=filtered_measure_node_2, metric_input_measure_specs=tuple( MetricInputMeasureSpec(measure_spec=x, fill_nulls_with=1) for x in measure_specs_2 ), ) - combine_output_node = CombineAggregatedOutputsNode([aggregated_measure_node, aggregated_measure_node_2]) + combine_output_node = CombineAggregatedOutputsNode.create([aggregated_measure_node, aggregated_measure_node_2]) convert_and_check( request=request, mf_test_configuration=mf_test_configuration, diff --git a/tests_metricflow/sql/optimizer/test_column_pruner.py b/tests_metricflow/sql/optimizer/test_column_pruner.py index 443c58fc80..a73feb22a5 100644 --- 
a/tests_metricflow/sql/optimizer/test_column_pruner.py +++ b/tests_metricflow/sql/optimizer/test_column_pruner.py @@ -70,51 +70,51 @@ def base_select_statement() -> SqlSelectStatementNode: ON from_source.join_col = joined_source.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col1") ), column_alias="from_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), column_alias="from_source_join_col", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col1") ), column_alias="joined_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), column_alias="joined_source_join_col", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="from_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col0", @@ -123,7 +123,7 @@ def 
base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col1", @@ -132,7 +132,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="join_col", @@ -141,17 +141,19 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="from_source_table", ), from_source_alias="from_source", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="joined_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col0", @@ -160,7 +162,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col1", @@ -169,7 +171,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="join_col", @@ -178,18 +180,18 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - 
from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="joined_source_table") ), from_source_alias="joined_source_table", ), right_source_alias="joined_source", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), ), @@ -228,29 +230,29 @@ def test_prune_from_source( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests a case where columns should be pruned from the FROM clause.""" - select_statement_with_some_from_source_column_removed = SqlSelectStatementNode( + select_statement_with_some_from_source_column_removed = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col1") ), column_alias="joined_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), column_alias="joined_source_join_col", @@ -258,7 +260,7 @@ def 
test_prune_from_source( ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, where=base_select_statement.where, @@ -287,29 +289,29 @@ def test_prune_joined_source( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests a case where columns should be pruned from the JOIN clause.""" - select_statement_with_some_joined_source_column_removed = SqlSelectStatementNode( + select_statement_with_some_joined_source_column_removed = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col1") ), column_alias="from_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), column_alias="from_source_join_col", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", @@ -317,7 +319,7 @@ def test_prune_joined_source( ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, where=base_select_statement.where, @@ -346,11 +348,11 @@ def 
test_dont_prune_if_in_where( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests that columns aren't pruned from parent sources if columns are used in a where.""" - select_statement_with_other_exprs = SqlSelectStatementNode( + select_statement_with_other_exprs = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", @@ -358,9 +360,11 @@ def test_dont_prune_if_in_where( ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, - where=SqlIsNullExpression( - SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="from_source", column_name="col1")) + join_descs=base_select_statement.join_descs, + where=SqlIsNullExpression.create( + SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="from_source", column_name="col1") + ) ), group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, @@ -389,17 +393,17 @@ def test_dont_prune_with_str_expr( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests that columns aren't pruned from parent sources if there's a string expression in the select.""" - select_statement_with_other_exprs = SqlSelectStatementNode( + select_statement_with_other_exprs = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlStringExpression("from_source.col0", requires_parenthesis=False), + expr=SqlStringExpression.create("from_source.col0", requires_parenthesis=False), column_alias="some_string_expr", ), ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, 
where=base_select_statement.where, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, @@ -428,17 +432,17 @@ def test_prune_with_str_expr( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests that columns are from parent sources if there's a string expression in the select with known cols.""" - select_statement_with_other_exprs = SqlSelectStatementNode( + select_statement_with_other_exprs = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlStringExpression("from_source.col0", requires_parenthesis=False, used_columns=("col0",)), + expr=SqlStringExpression.create("from_source.col0", requires_parenthesis=False, used_columns=("col0",)), column_alias="some_string_expr", ), ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, where=base_select_statement.where, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, @@ -486,19 +490,19 @@ def string_select_statement() -> SqlSelectStatementNode: ON from_source.join_col = joined_source.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="col0", used_columns=("col0",)), + expr=SqlStringExpression.create(sql_expr="col0", used_columns=("col0",)), column_alias="from_source_col0", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="from_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col0", @@ -507,7 +511,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - 
expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col1", @@ -516,7 +520,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="join_col", @@ -525,17 +529,19 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="from_source_table", ), from_source_alias="from_source", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="joined_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col2", @@ -544,7 +550,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col3", @@ -553,7 +559,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="join_col", @@ -562,18 +568,18 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( 
sql_table=SqlTable(schema_name="demo", table_name="joined_source_table") ), from_source_alias="joined_source_table", ), right_source_alias="joined_source", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), ), @@ -631,19 +637,19 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: ) src1 ) src2 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="col0"), + expr=SqlStringExpression.create(sql_expr="col0"), column_alias="col0", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="col0", @@ -652,7 +658,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="col1", @@ -661,11 +667,11 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col0", @@ -674,7 +680,7 @@ def 
grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col1", @@ -683,7 +689,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col2", @@ -692,7 +698,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col2", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src0")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="src0")), from_source_alias="src0", ), from_source_alias="src1", @@ -751,23 +757,25 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: ON src3.join_col = src4.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="4", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="col0"), + expr=SqlStringExpression.create(sql_expr="col0"), column_alias="col0", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="src3", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="col0", @@ -776,7 +784,7 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: 
column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="join_col", @@ -785,11 +793,11 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col0", @@ -798,7 +806,7 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col1", @@ -807,7 +815,7 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="join_col", @@ -816,18 +824,20 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src0")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="src0") + ), from_source_alias="src0", ), from_source_alias="src1", ), right_source_alias="src4", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src3", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( 
col_ref=SqlColumnReference(table_alias="src4", column_name="join_col") ), ), @@ -866,33 +876,35 @@ def test_prune_distinct_select( column_pruner: SqlColumnPrunerOptimizer, ) -> None: """Test that distinct select node shouldn't be pruned.""" - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="test1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", distinct=True, ), diff --git a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py index 87f5fd10a8..1536f30e66 100644 --- a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py @@ -54,14 +54,14 @@ def base_select_statement() -> SqlSelectStatementNode: GROUP BY src2.ds ORDER BY src2.ds """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src3", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + 
expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src2", column_name="bookings") ) ], @@ -69,29 +69,33 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="ds")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="ds") + ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="bookings") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src1", column_name="ds")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src1", column_name="ds") + ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="bookings", @@ -100,7 +104,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="ds", @@ -109,27 +113,29 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + 
sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="src0", limit=2, ), from_source_alias="src1", - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src1", column_name="ds", ) ), comparison=SqlComparison.GREATER_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-01"), + right_expr=SqlStringLiteralExpression.create("2020-01-01"), ), limit=1, ), from_source_alias="src2", group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="ds", @@ -138,19 +144,19 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="ds", ) ), comparison=SqlComparison.LESS_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-05"), + right_expr=SqlStringLiteralExpression.create("2020-01-05"), ), order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="ds", @@ -210,14 +216,14 @@ def join_select_statement() -> SqlSelectStatementNode: ON bookings_src.listing = listings_src.listing GROUP BY bookings_src.ds """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="bookings") ) ], @@ -225,72 +231,74 @@ 
def join_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src", column_name="country_latest") ), column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="ds") ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="bookings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="booking") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlStringExpression(sql_expr="1", requires_parenthesis=False, used_columns=()), + expr=SqlStringExpression.create(sql_expr="1", requires_parenthesis=False, used_columns=()), column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="fct_bookings_src", ), from_source_alias="bookings_src", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src", column_name="country") ), column_alias="country_latest", ), SqlSelectColumn( 
- expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="dim_listings_src", ), right_source_alias="listings_src", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="bookings_src", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src", column_name="listing"), ), ), @@ -299,7 +307,7 @@ def join_select_statement() -> SqlSelectStatementNode: ), group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", @@ -308,15 +316,15 @@ def join_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", ) ), comparison=SqlComparison.LESS_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-05"), + right_expr=SqlStringLiteralExpression.create("2020-01-05"), ), ) @@ -369,14 +377,14 @@ def colliding_select_statement() -> SqlSelectStatementNode: ON bookings_src.listing = listings_src.listing GROUP BY bookings_src.ds """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + 
expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="bookings") ) ], @@ -384,74 +392,76 @@ def colliding_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src", column_name="listing__country_latest") ), column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="ds") ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="bookings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="booking") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="ds") ), column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="colliding_alias", ), from_source_alias="bookings_src", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( 
description="listings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="country") ), column_alias="country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="colliding_alias", ), right_source_alias="listings_src", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="bookings_src", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src", column_name="listing"), ), ), @@ -460,7 +470,7 @@ def colliding_select_statement() -> SqlSelectStatementNode: ), group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", @@ -469,15 +479,15 @@ def colliding_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", ) ), comparison=SqlComparison.LESS_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-05"), + right_expr=SqlStringLiteralExpression.create("2020-01-05"), ), ) @@ -538,14 +548,14 @@ def reduce_all_join_select_statement() -> 
SqlSelectStatementNode: ON listing_src1.listing = listings_src2.listing GROUP BY bookings_src.ds, listings_src1.country_latest, listings_src2.capacity_latest """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="bookings") ) ], @@ -553,114 +563,116 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src1", column_name="country_latest") ), column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src2", column_name="capacity_latest") ), column_alias="listing__capacity_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="ds") ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="bookings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="booking") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="ds") ), column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( 
col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="fct_bookings_src", ), from_source_alias="bookings_src", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src1", column_name="country") ), column_alias="country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src1", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="dim_listings_src1", ), right_source_alias="listings_src1", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="bookings_src", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src1", column_name="listing"), ), ), join_type=SqlJoinType.LEFT_OUTER, ), SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + 
expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src2", column_name="capacity") ), column_alias="capacity_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src2", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="dim_listings_src2", ), right_source_alias="listings_src2", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src1", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src2", column_name="listing"), ), ), @@ -669,7 +681,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: ), group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", @@ -678,7 +690,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="listings_src1", column_name="country_latest", @@ -687,7 +699,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="listings_src2", column_name="capacity_latest", @@ -748,30 +760,30 @@ def reducing_join_statement() -> 
SqlSelectStatementNode: FROM demo.fct_listings src4 ) src3 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src2", column_name="bookings") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src3", column_name="listings") ), column_alias="listings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="bookings") ) ], @@ -779,30 +791,32 @@ def reducing_join_statement() -> SqlSelectStatementNode: column_alias="bookings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="1", requires_parenthesis=False, used_columns=()), + expr=SqlStringExpression.create(sql_expr="1", requires_parenthesis=False, used_columns=()), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="src0", ), from_source_alias="src1", ), from_source_alias="src2", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="src4", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + 
expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src4", column_name="listings") ) ], @@ -810,7 +824,7 @@ def reducing_join_statement() -> SqlSelectStatementNode: column_alias="listings", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="fct_listings") ), from_source_alias="src4", @@ -872,30 +886,30 @@ def reducing_join_left_node_statement() -> SqlSelectStatementNode: FROM demo.fct_listings src4 ) src3 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src2", column_name="bookings") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src3", column_name="listings") ), column_alias="listings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src4", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src4", column_name="listings") ) ], @@ -903,20 +917,22 @@ def reducing_join_left_node_statement() -> SqlSelectStatementNode: column_alias="listings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_listings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_listings") + ), from_source_alias="src4", ), from_source_alias="src2", - 
joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="bookings") ) ], @@ -924,15 +940,17 @@ def reducing_join_left_node_statement() -> SqlSelectStatementNode: column_alias="bookings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="1", requires_parenthesis=False, used_columns=()), + expr=SqlStringExpression.create( + sql_expr="1", requires_parenthesis=False, used_columns=() + ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") ), from_source_alias="src0", @@ -975,33 +993,35 @@ def test_rewriting_distinct_select_node_is_not_reduced( mf_test_configuration: MetricFlowTestConfiguration, ) -> None: """Tests to ensure distinct select node doesn't get overwritten.""" - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="test1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), 
SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", distinct=True, ), diff --git a/tests_metricflow/sql/optimizer/test_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_sub_query_reducer.py index ce0255beac..dc2bc157ee 100644 --- a/tests_metricflow/sql/optimizer/test_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_sub_query_reducer.py @@ -46,39 +46,43 @@ def base_select_statement() -> SqlSelectStatementNode: ) src2 ORDER BY src2.col0 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src3", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col0")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col0") + ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col1")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col1") + ), column_alias="col1", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="col0") ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="col1") ), 
column_alias="col1", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col0", @@ -87,7 +91,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col1", @@ -96,7 +100,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="from_source_table") ), from_source_alias="src0", @@ -108,7 +112,7 @@ def base_select_statement() -> SqlSelectStatementNode: from_source_alias="src2", order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="col0", @@ -163,47 +167,53 @@ def rewrite_order_by_statement() -> SqlSelectStatementNode: ORDER BY src2.col1 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src3", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col0")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col0") + ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col1")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col1") + ), column_alias="col1", ), ), from_source=( - SqlSelectStatementNode( + SqlSelectStatementNode.create( description="src2", select_columns=( 
SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src0", column_name="col0") ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="col1") ), column_alias="col1", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src0")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="src0")), from_source_alias="src0", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src1")), + right_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="src1") + ), right_source_alias="src1", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src0", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="join_col") ), ), @@ -215,7 +225,7 @@ def rewrite_order_by_statement() -> SqlSelectStatementNode: from_source_alias="src2", order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="col1", @@ -255,33 +265,35 @@ def test_distinct_select_node_is_not_reduced( mf_test_configuration: MetricFlowTestConfiguration, ) -> None: """Tests to ensure distinct select node doesn't get overwritten.""" - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - 
expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="test1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", distinct=True, ), diff --git a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py index 6dff3fcd3d..d603054f38 100644 --- a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py +++ b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py @@ -53,27 +53,27 @@ def base_select_statement() -> SqlSelectStatementNode: ON from_source.join_col = joined_source.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", ), ), - from_source=SqlSelectStatementNode( 
+ from_source=SqlSelectStatementNode.create( description="from_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col0", @@ -82,7 +82,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="join_col", @@ -91,17 +91,19 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="from_source_table", ), from_source_alias="from_source", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="joined_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col0", @@ -110,7 +112,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="join_col", @@ -119,18 +121,18 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="joined_source_table") ), from_source_alias="joined_source_table", ), right_source_alias="joined_source", - on_condition=SqlComparisonExpression( - 
left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), ), diff --git a/tests_metricflow/sql/test_engine_specific_rendering.py b/tests_metricflow/sql/test_engine_specific_rendering.py index 71b0383620..6ba8f8fc15 100644 --- a/tests_metricflow/sql/test_engine_specific_rendering.py +++ b/tests_metricflow/sql/test_engine_specific_rendering.py @@ -37,8 +37,8 @@ def test_cast_to_timestamp( """Tests rendering of the cast to timestamp expression in a query.""" select_columns = [ SqlSelectColumn( - expr=SqlCastToTimestampExpression( - arg=SqlStringLiteralExpression( + expr=SqlCastToTimestampExpression.create( + arg=SqlStringLiteralExpression.create( literal_value="2020-01-01", ) ), @@ -46,7 +46,7 @@ def test_cast_to_timestamp( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -56,12 +56,12 @@ def test_cast_to_timestamp( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Cast to Timestamp Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -80,17 +80,17 @@ def test_generate_uuid( """Tests rendering of the generate uuid expression in a 
query.""" select_columns = [ SqlSelectColumn( - expr=SqlGenerateUuidExpression(), + expr=SqlGenerateUuidExpression.create(), column_alias="uuid", ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Generate UUID Expression", select_columns=tuple(select_columns), from_source=from_source, @@ -115,8 +115,8 @@ def test_continuous_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.CONTINUOUS ), @@ -125,7 +125,7 @@ def test_continuous_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -135,12 +135,12 @@ def test_continuous_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Continuous Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -164,8 +164,8 @@ def 
test_discrete_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.DISCRETE ), @@ -174,7 +174,7 @@ def test_discrete_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -184,12 +184,12 @@ def test_discrete_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Discrete Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -213,8 +213,8 @@ def test_approximate_continuous_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS ), @@ -223,7 +223,7 @@ def test_approximate_continuous_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = 
SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -233,12 +233,12 @@ def test_approximate_continuous_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Approximate Continuous Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -262,8 +262,8 @@ def test_approximate_discrete_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.APPROXIMATE_DISCRETE ), @@ -272,7 +272,7 @@ def test_approximate_discrete_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -282,12 +282,12 @@ def test_approximate_discrete_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Approximate Discrete Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + 
join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), diff --git a/tests_metricflow/sql/test_sql_expr_render.py b/tests_metricflow/sql/test_sql_expr_render.py index 7e6ba4cb6e..05b4103ea2 100644 --- a/tests_metricflow/sql/test_sql_expr_render.py +++ b/tests_metricflow/sql/test_sql_expr_render.py @@ -45,14 +45,14 @@ def default_expr_renderer() -> DefaultSqlExpressionRenderer: # noqa: D103 def test_str_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - actual = default_expr_renderer.render_sql_expr(SqlStringExpression("a + b")).sql + actual = default_expr_renderer.render_sql_expr(SqlStringExpression.create("a + b")).sql expected = "a + b" assert actual == expected def test_col_ref_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlColumnReferenceExpression(SqlColumnReference("my_table", "my_col")) + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "my_col")) ).sql expected = "my_table.my_col" assert actual == expected @@ -60,10 +60,10 @@ def test_col_ref_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> No def test_comparison_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("my_table", "my_col")), + SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "my_col")), comparison=SqlComparison.EQUALS, - right_expr=SqlStringExpression("a + b"), + right_expr=SqlStringExpression.create("a + b"), ) ).sql assert actual == "my_table.my_col = (a + b)" @@ -71,10 +71,10 @@ def test_comparison_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> def test_require_parenthesis(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = 
default_expr_renderer.render_sql_expr( - SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("a", "booking_value")), + SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "booking_value")), comparison=SqlComparison.GREATER_THAN, - right_expr=SqlStringExpression("100", requires_parenthesis=False), + right_expr=SqlStringExpression.create("100", requires_parenthesis=False), ) ).sql @@ -83,11 +83,11 @@ def test_require_parenthesis(default_expr_renderer: DefaultSqlExpressionRenderer def test_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlAggregateFunctionExpression( + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "a")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "b")), ], ) ).sql @@ -97,11 +97,11 @@ def test_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> N def test_distinct_agg_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: """Distinct aggregation functions require the insertion of the DISTINCT keyword in the rendered function expr.""" actual = default_expr_renderer.render_sql_expr( - SqlAggregateFunctionExpression( + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COUNT_DISTINCT, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "a")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "b")), ], ) ).sql @@ -111,15 +111,15 @@ def 
test_distinct_agg_expr(default_expr_renderer: DefaultSqlExpressionRenderer) def test_nested_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlAggregateFunctionExpression( + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.CONCAT, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlAggregateFunctionExpression( + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "a")), + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.CONCAT, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), - SqlColumnReferenceExpression(SqlColumnReference("my_table", "c")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "b")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "c")), ], ), ], @@ -129,17 +129,17 @@ def test_nested_function_expr(default_expr_renderer: DefaultSqlExpressionRendere def test_null_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - actual = default_expr_renderer.render_sql_expr(SqlNullExpression()).sql + actual = default_expr_renderer.render_sql_expr(SqlNullExpression.create()).sql assert actual == "NULL" def test_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlLogicalExpression( + SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=( - SqlStringExpression("1 < 2", requires_parenthesis=True), - SqlStringExpression("foo", requires_parenthesis=False), + SqlStringExpression.create("1 < 2", requires_parenthesis=True), + SqlStringExpression.create("foo", requires_parenthesis=False), ), ) ).sql @@ -155,12 +155,12 @@ def test_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: def test_long_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 
actual = default_expr_renderer.render_sql_expr( - SqlLogicalExpression( + SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=( - SqlStringExpression("some_long_expression1"), - SqlStringExpression("some_long_expression2"), - SqlStringExpression("some_long_expression3"), + SqlStringExpression.create("some_long_expression1"), + SqlStringExpression.create("some_long_expression2"), + SqlStringExpression.create("some_long_expression3"), ), ) ).sql @@ -181,56 +181,59 @@ def test_long_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> N def test_string_literal_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - actual = default_expr_renderer.render_sql_expr(SqlStringLiteralExpression("foo")).sql + actual = default_expr_renderer.render_sql_expr(SqlStringLiteralExpression.create("foo")).sql assert actual == "'foo'" def test_is_null_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlIsNullExpression(SqlStringExpression("foo", requires_parenthesis=False)) + SqlIsNullExpression.create(SqlStringExpression.create("foo", requires_parenthesis=False)) ).sql assert actual == "foo IS NULL" def test_date_trunc_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlDateTruncExpression(time_granularity=TimeGranularity.MONTH, arg=SqlStringExpression("ds")) + SqlDateTruncExpression.create(time_granularity=TimeGranularity.MONTH, arg=SqlStringExpression.create("ds")) ).sql assert actual == "DATE_TRUNC('month', ds)" def test_extract_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlExtractExpression(date_part=DatePart.DOY, arg=SqlStringExpression("ds")) + SqlExtractExpression.create(date_part=DatePart.DOY, arg=SqlStringExpression.create("ds")) ).sql assert actual == "EXTRACT(doy FROM ds)" def 
test_ratio_computation_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlRatioComputationExpression( - numerator=SqlAggregateFunctionExpression( - SqlFunction.SUM, sql_function_args=[SqlStringExpression(sql_expr="1", requires_parenthesis=False)] + SqlRatioComputationExpression.create( + numerator=SqlAggregateFunctionExpression.create( + SqlFunction.SUM, + sql_function_args=[SqlStringExpression.create(sql_expr="1", requires_parenthesis=False)], + ), + denominator=SqlColumnReferenceExpression.create( + SqlColumnReference(column_name="divide_by_me", table_alias="a") ), - denominator=SqlColumnReferenceExpression(SqlColumnReference(column_name="divide_by_me", table_alias="a")), ), ).sql assert actual == "CAST(SUM(1) AS DOUBLE) / CAST(NULLIF(a.divide_by_me, 0) AS DOUBLE)" def test_expr_rewrite(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - expr = SqlLogicalExpression( + expr = SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=( - SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), - SqlColumnReferenceExpression(SqlColumnReference("a", "col1")), + SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), + SqlColumnReferenceExpression.create(SqlColumnReference("a", "col1")), ), ) column_replacements = SqlColumnReplacements( { - SqlColumnReference("a", "col0"): SqlStringExpression("foo", requires_parenthesis=False), - SqlColumnReference("a", "col1"): SqlStringExpression("bar", requires_parenthesis=False), + SqlColumnReference("a", "col0"): SqlStringExpression.create("foo", requires_parenthesis=False), + SqlColumnReference("a", "col1"): SqlStringExpression.create("bar", requires_parenthesis=False), } ) expr_rewritten = expr.rewrite(column_replacements) @@ -239,15 +242,15 @@ def test_expr_rewrite(default_expr_renderer: DefaultSqlExpressionRenderer) -> No def test_between_expr(default_expr_renderer: 
DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlBetweenExpression( - column_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), - start_expr=SqlCastToTimestampExpression( - arg=SqlStringLiteralExpression( + SqlBetweenExpression.create( + column_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), + start_expr=SqlCastToTimestampExpression.create( + arg=SqlStringLiteralExpression.create( literal_value="2020-01-01", ) ), - end_expr=SqlCastToTimestampExpression( - arg=SqlStringLiteralExpression( + end_expr=SqlCastToTimestampExpression.create( + arg=SqlStringLiteralExpression.create( literal_value="2020-01-10", ) ), @@ -262,17 +265,17 @@ def test_window_function_expr( # noqa: D103 default_expr_renderer: DefaultSqlExpressionRenderer, ) -> None: partition_by_args = ( - SqlColumnReferenceExpression(SqlColumnReference("b", "col0")), - SqlColumnReferenceExpression(SqlColumnReference("b", "col1")), + SqlColumnReferenceExpression.create(SqlColumnReference("b", "col0")), + SqlColumnReferenceExpression.create(SqlColumnReference("b", "col1")), ) order_by_args = ( SqlWindowOrderByArgument( - expr=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), descending=True, nulls_last=False, ), SqlWindowOrderByArgument( - expr=SqlColumnReferenceExpression(SqlColumnReference("b", "col0")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "col0")), descending=False, nulls_last=True, ), @@ -284,9 +287,9 @@ def test_window_function_expr( # noqa: D103 rendered_sql_lines.append(f"-- Window function with {num_partition_by_args} PARTITION BY items(s)") rendered_sql_lines.append( default_expr_renderer.render_sql_expr( - SqlWindowFunctionExpression( + SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, - sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", 
"col0"))], + sql_function_args=[SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0"))], partition_by_args=partition_by_args[:num_partition_by_args], order_by_args=(), ) @@ -298,9 +301,9 @@ def test_window_function_expr( # noqa: D103 rendered_sql_lines.append(f"-- Window function with {num_order_by_args} ORDER BY items(s)") rendered_sql_lines.append( default_expr_renderer.render_sql_expr( - SqlWindowFunctionExpression( + SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, - sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", "col0"))], + sql_function_args=[SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0"))], partition_by_args=(), order_by_args=order_by_args[:num_order_by_args], ) @@ -311,9 +314,9 @@ def test_window_function_expr( # noqa: D103 rendered_sql_lines.append("-- Window function with PARTITION BY and ORDER BY items") rendered_sql_lines.append( default_expr_renderer.render_sql_expr( - SqlWindowFunctionExpression( + SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, - sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", "col0"))], + sql_function_args=[SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0"))], partition_by_args=partition_by_args, order_by_args=order_by_args, ) diff --git a/tests_metricflow/sql/test_sql_plan_render.py b/tests_metricflow/sql/test_sql_plan_render.py index f844cee00d..14dd164a89 100644 --- a/tests_metricflow/sql/test_sql_plan_render.py +++ b/tests_metricflow/sql/test_sql_plan_render.py @@ -42,14 +42,14 @@ def test_component_rendering( # Test single SELECT column select_columns = [ SqlSelectColumn( - expr=SqlAggregateFunctionExpression( - sql_function=SqlFunction.SUM, sql_function_args=[SqlStringExpression("1")] + expr=SqlAggregateFunctionExpression.create( + sql_function=SqlFunction.SUM, sql_function_args=[SqlStringExpression.create("1")] ), column_alias="bookings", ), ] - from_source = 
SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")) from_source = from_source from_source_alias = "a" @@ -61,12 +61,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -79,11 +79,11 @@ def test_component_rendering( select_columns.extend( [ SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("b", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "country")), column_alias="user__country", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("c", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("c", "country")), column_alias="listing__country", ), ] @@ -92,12 +92,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test1", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -109,12 +109,12 @@ def test_component_rendering( # Test single join joins_descs.append( SqlJoinDescription( - right_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="dim_users")), + 
right_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="dim_users")), right_source_alias="b", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("a", "user_id")), + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "user_id")), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression(SqlColumnReference("b", "user_id")), + right_expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "user_id")), ), join_type=SqlJoinType.LEFT_OUTER, ) @@ -123,12 +123,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test2", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -140,12 +140,14 @@ def test_component_rendering( # Test multiple join joins_descs.append( SqlJoinDescription( - right_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="dim_listings")), + right_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="dim_listings") + ), right_source_alias="c", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("a", "user_id")), + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "user_id")), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression(SqlColumnReference("c", "user_id")), + right_expr=SqlColumnReferenceExpression.create(SqlColumnReference("c", "user_id")), ), join_type=SqlJoinType.LEFT_OUTER, ) @@ -154,12 +156,12 @@ def 
test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test3", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -171,7 +173,7 @@ def test_component_rendering( # Test single group by group_bys.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("b", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "country")), column_alias="user__country", ), ) @@ -179,12 +181,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test4", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -196,7 +198,7 @@ def test_component_rendering( # Test multiple group bys group_bys.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("c", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("c", "country")), column_alias="listing__country", ), ) @@ -204,12 +206,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test5", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), 
order_bys=tuple(order_bys), @@ -228,24 +230,26 @@ def test_render_where( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), comparison=SqlComparison.GREATER_THAN, - right_expr=SqlStringExpression( + right_expr=SqlStringExpression.create( sql_expr="100", requires_parenthesis=False, used_columns=(), @@ -266,33 +270,35 @@ def test_render_order_by( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + 
from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), desc=False, ), SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), desc=True, @@ -313,17 +319,19 @@ def test_render_limit( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", limit=1, ), @@ -338,22 +346,24 @@ def test_render_create_table_as( # noqa: D103 mf_test_configuration: MetricFlowTestConfiguration, sql_client: SqlClient, ) -> None: - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="select_0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="a", column_name="bookings")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="a", column_name="bookings") + ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + 
from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), from_source_alias="a", limit=1, ) assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlCreateTableAsNode( + sql_plan_node=SqlCreateTableAsNode.create( sql_table=SqlTable( schema_name="schema_name", table_name="table_name", @@ -367,7 +377,7 @@ def test_render_create_table_as( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlCreateTableAsNode( + sql_plan_node=SqlCreateTableAsNode.create( sql_table=SqlTable( schema_name="schema_name", table_name="table_name", diff --git a/tests_metricflow/sql_clients/test_date_time_operations.py b/tests_metricflow/sql_clients/test_date_time_operations.py index a99b4e58bd..0370ade60e 100644 --- a/tests_metricflow/sql_clients/test_date_time_operations.py +++ b/tests_metricflow/sql_clients/test_date_time_operations.py @@ -44,8 +44,8 @@ def _extract_data_table_value(df: MetricFlowDataTable) -> Any: # type: ignore[m def _build_date_trunc_expression(date_string: str, time_granularity: TimeGranularity) -> SqlDateTruncExpression: - cast_expr = SqlCastToTimestampExpression(SqlStringLiteralExpression(literal_value=date_string)) - return SqlDateTruncExpression(time_granularity=time_granularity, arg=cast_expr) + cast_expr = SqlCastToTimestampExpression.create(SqlStringLiteralExpression.create(literal_value=date_string)) + return SqlDateTruncExpression.create(time_granularity=time_granularity, arg=cast_expr) def test_date_trunc_to_year(sql_client: SqlClient) -> None: @@ -118,8 +118,8 @@ def test_date_trunc_to_week(sql_client: SqlClient, input: str, expected: datetim def _build_extract_expression(date_string: str, date_part: DatePart) -> SqlExtractExpression: - cast_expr = SqlCastToTimestampExpression(SqlStringLiteralExpression(literal_value=date_string)) - return SqlExtractExpression(date_part=date_part, 
arg=cast_expr) + cast_expr = SqlCastToTimestampExpression.create(SqlStringLiteralExpression.create(literal_value=date_string)) + return SqlExtractExpression.create(date_part=date_part, arg=cast_expr) def test_date_part_year(sql_client: SqlClient) -> None: