From 5986a03b0eeae8bc85445b956cc82903077fd590 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 20 Nov 2024 09:20:12 -0800 Subject: [PATCH] Simplify dataflow to SQL logic for JoinOverTimeRangeNode There should be no functional changes in this commit, only cleanup and readability improvements. Mostly involves moving complex logic to helper functions. --- metricflow/dataflow/nodes/join_over_time.py | 6 +- metricflow/dataset/sql_dataset.py | 23 +++++ metricflow/plan_conversion/dataflow_to_sql.py | 84 ++++++++----------- .../plan_conversion/instance_converters.py | 1 + .../plan_conversion/sql_join_builder.py | 19 +---- 5 files changed, 63 insertions(+), 70 deletions(-) diff --git a/metricflow/dataflow/nodes/join_over_time.py b/metricflow/dataflow/nodes/join_over_time.py index c766bb7dc..82efccf88 100644 --- a/metricflow/dataflow/nodes/join_over_time.py +++ b/metricflow/dataflow/nodes/join_over_time.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple from dbt_semantic_interfaces.protocols import MetricTimeWindow from dbt_semantic_interfaces.type_enums import TimeGranularity @@ -26,7 +26,7 @@ class JoinOverTimeRangeNode(DataflowPlanNode): time_range_constraint: Time range to aggregate over. """ - queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec] + queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...] window: Optional[MetricTimeWindow] grain_to_date: Optional[TimeGranularity] time_range_constraint: Optional[TimeRangeConstraint] @@ -38,7 +38,7 @@ def __post_init__(self) -> None: # noqa: D105 @staticmethod def create( # noqa: D102 parent_node: DataflowPlanNode, - queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec], + queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...], window: Optional[MetricTimeWindow] = None, grain_to_date: Optional[TimeGranularity] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index 5e18bf6d3..4bb530e71 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import SemanticModelReference @@ -160,3 +161,25 @@ def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensi @override def semantic_model_reference(self) -> Optional[SemanticModelReference]: return None + + def annotate(self, alias: str, metric_time_spec: TimeDimensionSpec) -> AnnotatedSqlDataSet: + """Convert to an AnnotatedSqlDataSet with specified metadata.""" + metric_time_column_name = self.column_association_for_time_dimension(metric_time_spec).column_name + return AnnotatedSqlDataSet(data_set=self, alias=alias, _metric_time_column_name=metric_time_column_name) + + +@dataclass(frozen=True) +class AnnotatedSqlDataSet: + """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan.""" + + data_set: SqlDataSet + alias: str + _metric_time_column_name: Optional[str] = None + + @property + def metric_time_column_name(self) -> str: + """Direct accessor for the optional metric time name, only safe to call when we know that value is set.""" + assert ( + self._metric_time_column_name + ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!" + return self._metric_time_column_name diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index cea95e115..526618a9f 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -468,70 +468,41 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet: """Generate time range join SQL.""" table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict() - input_data_set = node.parent_node.accept(self) - input_data_set_alias = self._next_unique_table_alias() + parent_data_set = node.parent_node.accept(self) + parent_data_set_alias = self._next_unique_table_alias() - # Find requested agg_time_dimensions in parent instance set. - # Will use instance with the smallest base granularity in time spine join. - agg_time_dimension_instance_for_join: Optional[TimeDimensionInstance] = None - requested_agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] = () - for instance in input_data_set.instance_set.time_dimension_instances: - if instance.spec in node.queried_agg_time_dimension_specs: - requested_agg_time_dimension_instances += (instance,) - if not agg_time_dimension_instance_for_join or ( - instance.spec.time_granularity.base_granularity.to_int() - < agg_time_dimension_instance_for_join.spec.time_granularity.base_granularity.to_int() - ): - agg_time_dimension_instance_for_join = instance - assert ( - agg_time_dimension_instance_for_join - ), "Specified metric time spec not found in parent data set. This should have been caught by validations." + # For the purposes of this node, use base grains. Custom grains will be joined later in the dataflow plan. + agg_time_dimension_specs = tuple({spec.with_base_grain() for spec in node.queried_agg_time_dimension_specs}) + # Assemble time_spine dataset with a column for each agg_time_dimension requested. + agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(agg_time_dimension_specs) time_spine_data_set_alias = self._next_unique_table_alias() - - # Assemble time_spine dataset with requested agg time dimension instances selected. time_spine_data_set = self._make_time_spine_data_set( - agg_time_dimension_instances=requested_agg_time_dimension_instances, - time_range_constraint=node.time_range_constraint, + agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint ) - table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set + # Build the join description. + join_spec = self._choose_instance_for_time_spine_join(agg_time_dimension_instances).spec + annotated_parent = parent_data_set.annotate(alias=parent_data_set_alias, metric_time_spec=join_spec) + annotated_time_spine = time_spine_data_set.annotate(alias=time_spine_data_set_alias, metric_time_spec=join_spec) join_desc = SqlQueryPlanJoinBuilder.make_cumulative_metric_time_range_join_description( - node=node, - metric_data_set=AnnotatedSqlDataSet( - data_set=input_data_set, - alias=input_data_set_alias, - _metric_time_column_name=input_data_set.column_association_for_time_dimension( - agg_time_dimension_instance_for_join.spec - ).column_name, - ), - time_spine_data_set=AnnotatedSqlDataSet( - data_set=time_spine_data_set, - alias=time_spine_data_set_alias, - _metric_time_column_name=time_spine_data_set.column_association_for_time_dimension( - agg_time_dimension_instance_for_join.spec - ).column_name, - ), + node=node, metric_data_set=annotated_parent, time_spine_data_set=annotated_time_spine ) - # Remove instances of agg_time_dimension from input data set. They'll be replaced with time spine instances. - agg_time_dimension_specs = tuple(dim.spec for dim in requested_agg_time_dimension_instances) - modified_input_instance_set = input_data_set.instance_set.transform( + # Build select columns, replacing agg_time_dimensions from the parent node with columns from the time spine. + table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set + table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform( FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs)) ) - table_alias_to_instance_set[input_data_set_alias] = modified_input_instance_set - - # The output instances are the same as the input instances. - output_instance_set = ChangeAssociatedColumns(self._column_association_resolver).transform( - input_data_set.instance_set + select_columns = create_simple_select_columns_for_instance_sets( + column_resolver=self._column_association_resolver, table_alias_to_instance_set=table_alias_to_instance_set ) + return SqlDataSet( - instance_set=output_instance_set, + instance_set=parent_data_set.instance_set, # The output instances are the same as the input instances. sql_select_node=SqlSelectStatementNode.create( description=node.description, - select_columns=create_simple_select_columns_for_instance_sets( - self._column_association_resolver, table_alias_to_instance_set - ), + select_columns=select_columns, from_source=time_spine_data_set.checked_sql_select_node, from_source_alias=time_spine_data_set_alias, join_descs=(join_desc,), @@ -1392,6 +1363,21 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ), ) + def _choose_instance_for_time_spine_join( + self, agg_time_dimension_instances: Sequence[TimeDimensionInstance] + ) -> TimeDimensionInstance: + """Find the agg_time_dimension instance with the smallest grain to use for the time spine join.""" + # We can't use a date part spec to join to the time spine, so filter those out. + agg_time_dimension_instances = [ + instance for instance in agg_time_dimension_instances if not instance.spec.date_part + ] + assert len(agg_time_dimension_instances) > 0, ( + "No appropriate agg_time_dimension was found to join to the time spine. " + "This indicates that the dataflow plan was configured incorrectly." + ) + agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int()) + return agg_time_dimension_instances[0] + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index c6e25835b..b801d958d 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -803,6 +803,7 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102 ) +# TODO: delete this class & all uses. It doesn't do anything. class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]): """Change the columns associated with instances to the one specified by the resolver. diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 8aa859823..284e20360 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -13,7 +13,7 @@ from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinDescription from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode -from metricflow.dataset.sql_dataset import SqlDataSet +from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr from metricflow.sql.sql_exprs import ( SqlColumnReference, @@ -46,23 +46,6 @@ class ColumnEqualityDescription: treat_nulls_as_equal: bool = False -@dataclass(frozen=True) -class AnnotatedSqlDataSet: - """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan.""" - - data_set: SqlDataSet - alias: str - _metric_time_column_name: Optional[str] = None - - @property - def metric_time_column_name(self) -> str: - """Direct accessor for the optional metric time name, only safe to call when we know that value is set.""" - assert ( - self._metric_time_column_name - ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!" - return self._metric_time_column_name - - class SqlQueryPlanJoinBuilder: """Helper class for constructing various join components in a SqlQueryPlan."""