From 52f93220ff7a7ffba311d4a89e5f9644f1260563 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Thu, 21 Nov 2024 17:58:37 -0800 Subject: [PATCH] WIP --- .../dataflow/builder/dataflow_plan_builder.py | 51 ++++++-- metricflow/dataflow/builder/source_node.py | 1 + .../dataflow/nodes/join_to_time_spine.py | 24 +--- metricflow/dataset/sql_dataset.py | 16 +++ metricflow/plan_conversion/dataflow_to_sql.py | 117 ++++++++++-------- .../plan_conversion/sql_join_builder.py | 7 +- 6 files changed, 129 insertions(+), 87 deletions(-) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 79e1a81a3a..ccf1ea2712 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -87,8 +87,10 @@ from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode +from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode +from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode @@ -646,8 +648,11 @@ def _build_derived_metric_output_node( queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric( metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup ) + time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs) + # TODO: No where constraint needed here, but might need to 
apply distinct values if the base grain isn't selected. output_node = JoinToTimeSpineNode.create( parent_node=output_node, + time_spine_node=time_spine_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, offset_window=metric_spec.offset_window, offset_to_grain=metric_spec.offset_to_grain, @@ -1037,8 +1042,7 @@ def _find_source_node_recipe_non_cached( ) # If metric_time is requested without metrics, choose appropriate time spine node to select those values from. if linkable_specs_to_satisfy.metric_time_specs: - time_spine_source = self._choose_time_spine_source(linkable_specs_to_satisfy.metric_time_specs) - time_spine_node = self._source_node_set.time_spine_metric_time_nodes[time_spine_source.base_granularity] + time_spine_node = self._choose_time_spine_metric_time_node(linkable_specs_to_satisfy.metric_time_specs) candidate_nodes_for_right_side_of_join += [time_spine_node] candidate_nodes_for_left_side_of_join += [time_spine_node] default_join_type = SqlJoinType.FULL_OUTER @@ -1619,10 +1623,11 @@ def _build_aggregated_measure_from_measure_source_node( f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a " f"new use case." ) - # This also uses the original time range constraint due to the application of the time window intervals - # in join rendering + time_spine_node = self._build_time_spine_node(base_agg_time_dimension_specs) + # TODO: No where constraint needed here, but might need to apply distinct values if the base grain isn't selected. 
unaggregated_measure_node = JoinToTimeSpineNode.create( parent_node=unaggregated_measure_node, + time_spine_node=time_spine_node, requested_agg_time_dimension_specs=base_agg_time_dimension_specs, offset_window=before_aggregation_time_spine_join_description.offset_window, offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain, @@ -1684,19 +1689,25 @@ def _build_aggregated_measure_from_measure_source_node( else: non_agg_time_filters.append(filter_spec) - # TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here - # like JoinToCustomGranularityNode, WhereConstraintNode, etc. queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure( measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup ) + time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs) + filtered_time_spine_node = self._build_pre_aggregation_plan( + source_node=time_spine_node, + # TODO: Also need the join on spec, right? Figure that out. 
+            # filter_to_specs=InstanceSpecSet.create_from_specs(queried_agg_time_dimension_specs), +            custom_granularity_specs=tuple( +                spec for spec in queried_agg_time_dimension_specs if spec.time_granularity.is_custom_granularity +            ), +            time_range_constraint=predicate_pushdown_state.time_range_constraint, +            where_filter_specs=agg_time_only_filters, +        ) output_node: DataflowPlanNode = JoinToTimeSpineNode.create( parent_node=aggregate_measures_node, +            time_spine_node=filtered_time_spine_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, join_type=after_aggregation_time_spine_join_description.join_type, -            time_range_constraint=predicate_pushdown_state.time_range_constraint, -            offset_window=after_aggregation_time_spine_join_description.offset_window, -            offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain, -            time_spine_filters=agg_time_only_filters, )  # Since new rows might have been added due to time spine join, re-apply constraints here. Only re-apply filters @@ -1812,3 +1823,23 @@ def _choose_time_spine_source(self, required_time_spine_specs: Sequence[TimeDime required_time_spine_specs=required_time_spine_specs, time_spine_sources=self._source_node_builder.time_spine_sources, ) + +    def _choose_time_spine_metric_time_node( +        self, required_time_spine_specs: Sequence[TimeDimensionSpec] +    ) -> MetricTimeDimensionTransformNode: +        """Return the MetricTimeDimensionTransform time spine node needed to satisfy the specs.""" +        time_spine_source = self._choose_time_spine_source(required_time_spine_specs) +        return self._source_node_set.time_spine_metric_time_nodes[time_spine_source.base_granularity] + +    def _choose_time_spine_read_node(self, required_time_spine_specs: Sequence[TimeDimensionSpec]) -> ReadSqlSourceNode: +        """Return the ReadSqlSourceNode time spine node needed to satisfy the specs.""" +        time_spine_source = self._choose_time_spine_source(required_time_spine_specs) +        return 
self._source_node_set.time_spine_read_nodes[time_spine_source.base_granularity.value] + + def _build_time_spine_node(self, required_time_spine_specs: Sequence[TimeDimensionSpec]) -> DataflowPlanNode: + """Return the time spine node needed to satisfy the specs.""" + original_time_spine_node = self._choose_time_spine_read_node(required_time_spine_specs) + # TODO: build this node. Transform columns to the requested ones + return TransformTimeDimensionsNode( + parent_node=original_time_spine_node, required_time_spine_specs=required_time_spine_specs + ) diff --git a/metricflow/dataflow/builder/source_node.py b/metricflow/dataflow/builder/source_node.py index 0192e21277..68f9f82d50 100644 --- a/metricflow/dataflow/builder/source_node.py +++ b/metricflow/dataflow/builder/source_node.py @@ -36,6 +36,7 @@ class SourceNodeSet: # Semantic models are 1:1 mapped to a ReadSqlSourceNode. source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...] + # TODO: maybe this didn't need to have string keys, check later # Provides time spines that can be used to satisfy time spine joins, organized by granularity name. 
time_spine_read_nodes: Mapping[str, ReadSqlSourceNode] diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py index fd86503813..4ab383f356 100644 --- a/metricflow/dataflow/nodes/join_to_time_spine.py +++ b/metricflow/dataflow/nodes/join_to_time_spine.py @@ -8,9 +8,7 @@ from dbt_semantic_interfaces.type_enums import TimeGranularity from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty -from metricflow_semantics.filters.time_constraint import TimeRangeConstraint from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec -from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.visitor import VisitorOutputT @@ -30,12 +28,11 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC): offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine. 
""" + time_spine_node: DataflowPlanNode requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec] join_type: SqlJoinType - time_range_constraint: Optional[TimeRangeConstraint] offset_window: Optional[MetricTimeWindow] offset_to_grain: Optional[TimeGranularity] - time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None def __post_init__(self) -> None: # noqa: D105 super().__post_init__() @@ -51,21 +48,19 @@ def __post_init__(self) -> None: # noqa: D105 @staticmethod def create( # noqa: D102 parent_node: DataflowPlanNode, + time_spine_node: DataflowPlanNode, requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec], join_type: SqlJoinType, - time_range_constraint: Optional[TimeRangeConstraint] = None, offset_window: Optional[MetricTimeWindow] = None, offset_to_grain: Optional[TimeGranularity] = None, - time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None, ) -> JoinToTimeSpineNode: return JoinToTimeSpineNode( parent_nodes=(parent_node,), + time_spine_node=time_spine_node, requested_agg_time_dimension_specs=tuple(requested_agg_time_dimension_specs), join_type=join_type, - time_range_constraint=time_range_constraint, offset_window=offset_window, offset_to_grain=offset_to_grain, - time_spine_filters=time_spine_filters, ) @classmethod @@ -89,14 +84,6 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 props += (DisplayedProperty("offset_window", self.offset_window),) if self.offset_to_grain: props += (DisplayedProperty("offset_to_grain", self.offset_to_grain),) - if self.time_range_constraint: - props += (DisplayedProperty("time_range_constraint", self.time_range_constraint),) - if self.time_spine_filters: - props += ( - DisplayedProperty( - "time_spine_filters", [time_spine_filter.where_sql for time_spine_filter in self.time_spine_filters] - ), - ) return props @property @@ -106,22 +93,19 @@ def parent_node(self) -> DataflowPlanNode: # noqa: D102 def functionally_identical(self, other_node: DataflowPlanNode) -> bool: 
# noqa: D102 return ( isinstance(other_node, self.__class__) - and other_node.time_range_constraint == self.time_range_constraint and other_node.offset_window == self.offset_window and other_node.offset_to_grain == self.offset_to_grain and other_node.requested_agg_time_dimension_specs == self.requested_agg_time_dimension_specs and other_node.join_type == self.join_type - and other_node.time_spine_filters == self.time_spine_filters ) def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinToTimeSpineNode: # noqa: D102 assert len(new_parent_nodes) == 1 return JoinToTimeSpineNode.create( parent_node=new_parent_nodes[0], + time_spine_node=self.time_spine_node, requested_agg_time_dimension_specs=self.requested_agg_time_dimension_specs, - time_range_constraint=self.time_range_constraint, offset_window=self.offset_window, offset_to_grain=self.offset_to_grain, join_type=self.join_type, - time_spine_filters=self.time_spine_filters, ) diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index 2b86eb4799..6b52f57411 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -147,6 +147,22 @@ def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> """Given the name of the time dimension, return the instance associated with it in the data set.""" return self.instances_for_time_dimensions((time_dimension_spec,))[0] + def instance_from_time_dimension_grain_and_date_part( + self, time_dimension_spec: TimeDimensionSpec + ) -> TimeDimensionInstance: + """Find instance in dataset that matches the grain and date part of the given time dimension spec.""" + for time_dimension_instance in self.instance_set.time_dimension_instances: + if ( + time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity + and time_dimension_instance.spec.date_part == time_dimension_spec.date_part + ): + return time_dimension_instance + + raise RuntimeError( + f"Did not find 
a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n" + f"Instances available: {self.instance_set.time_dimension_instances}" + ) + def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation: """Given the name of the time dimension, return the set of columns associated with it in the data set.""" return self.instance_for_time_dimension(time_dimension_spec).associated_column diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index a92637f231..5facddee5e 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -1370,21 +1370,19 @@ def _choose_instance_for_time_spine_join( def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() - - agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions( - node.requested_agg_time_dimension_specs - ) + time_spine_data_set = node.time_spine_node.accept(self) + time_spine_alias = self._next_unique_table_alias() # Select the dimension for the join from the parent node because it may not have been included in the request. # Default to using metric_time for the join if it was requested, otherwise use the agg_time_dimension. 
- included_metric_time_instances = [ - instance for instance in agg_time_dimension_instances if instance.spec.is_metric_time - ] - if included_metric_time_instances: - join_on_time_dimension_sample = included_metric_time_instances[0].spec + + # TODO: could use helper if this were a spec set + metric_time_specs = [spec for spec in node.requested_agg_time_dimension_specs if spec.is_metric_time] + if metric_time_specs: + join_on_time_dimension_sample = metric_time_specs[0] else: - join_on_time_dimension_sample = agg_time_dimension_instances[0].spec - agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join( + join_on_time_dimension_sample = node.requested_agg_time_dimension_specs[0] + parent_agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join( [ instance for instance in parent_data_set.instance_set.time_dimension_instances @@ -1392,75 +1390,86 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet and instance.spec.entity_links == join_on_time_dimension_sample.entity_links ] ) - if agg_time_dimension_instance_for_join not in agg_time_dimension_instances: - agg_time_dimension_instances = (agg_time_dimension_instance_for_join,) + agg_time_dimension_instances - - # Build time spine data set with just the agg_time_dimension instance needed for the join. 
- time_spine_alias = self._next_unique_table_alias() - time_spine_dataset = self._make_time_spine_data_set( - agg_time_dimension_instances=agg_time_dimension_instances, - time_range_constraint=node.time_range_constraint, - time_spine_where_constraints=node.time_spine_filters or (), + required_agg_time_dimension_specs = tuple(node.requested_agg_time_dimension_specs) + if parent_agg_time_dimension_instance_for_join.spec not in node.requested_agg_time_dimension_specs: + required_agg_time_dimension_specs += (parent_agg_time_dimension_instance_for_join.spec,) + parent_agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions( + required_agg_time_dimension_specs ) # Build join expression. + time_spine_instance_for_join = time_spine_data_set.instance_from_time_dimension_grain_and_date_part( + parent_agg_time_dimension_instance_for_join.spec + ) join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description( node=node, time_spine_alias=time_spine_alias, - agg_time_dimension_column_name=self._column_association_resolver.resolve_spec( - agg_time_dimension_instance_for_join.spec - ).column_name, + time_spine_column_name=time_spine_instance_for_join.associated_column.column_name, parent_sql_select_node=parent_data_set.checked_sql_select_node, parent_alias=parent_alias, + parent_column_name=parent_agg_time_dimension_instance_for_join.associated_column.column_name, ) - # Remove time spine instances from parent instance set. - time_spine_instances = time_spine_dataset.instance_set - time_spine_specs = time_spine_instances.spec_set - parent_instance_set = parent_data_set.instance_set.transform(FilterElements(exclude_specs=time_spine_specs)) - - # Build select columns + # Remove time spine specs from parent instance set & build parent select columns. 
+ parent_instance_set = parent_data_set.instance_set.transform( + FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=required_agg_time_dimension_specs)) + ) select_columns = create_simple_select_columns_for_instance_sets( - self._column_association_resolver, - OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_dataset.instance_set}), + self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set}) ) # If offset_to_grain is used, will need to filter down to rows that match selected granularities. # Does not apply if one of the granularities selected matches the time spine column granularity. where_filter: Optional[SqlExpressionNode] = None need_where_filter = ( + # should this be required_agg_time_dimension_specs? node.offset_to_grain - and agg_time_dimension_instance_for_join.spec not in node.requested_agg_time_dimension_specs + and time_spine_instance_for_join.spec not in node.requested_agg_time_dimension_specs ) - if need_where_filter: - join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_alias, + column_name=time_spine_instance_for_join.associated_column.column_name, + ) + + # Build new select columns and instances for the time spine. + # TODO: do I need to assert the length of the list here? + time_spine_defined_from = time_spine_data_set.instance_set.time_dimension_instances[0].defined_from + new_time_spine_instances: Tuple[TimeDimensionInstance, ...] 
= () + for parent_instance in parent_agg_time_dimension_instances: + new_time_spine_instances += (parent_instance.with_new_defined_from(time_spine_defined_from),) + time_spine_instance = time_spine_data_set.instance_from_time_dimension_grain_and_date_part( + parent_instance.spec + ) + expr = SqlColumnReferenceExpression.from_table_and_column_names( table_alias=time_spine_alias, - column_name=agg_time_dimension_instance_for_join.associated_column.column_name, + column_name=time_spine_instance.associated_column.column_name, ) - for time_spine_instance in time_spine_instances.as_tuple: - # Filter down to one row per granularity period requested in the group by. Any other granularities - # included here will be filtered out in later nodes so should not be included in where filter. - if need_where_filter and time_spine_instance.spec in node.requested_agg_time_dimension_specs: - column_to_filter_expr = SqlColumnReferenceExpression.from_table_and_column_names( - table_alias=time_spine_alias, column_name=time_spine_instance.associated_column.column_name - ) - new_where_filter = SqlComparisonExpression.create( - left_expr=column_to_filter_expr, comparison=SqlComparison.EQUALS, right_expr=join_column_expr - ) - where_filter = ( - SqlLogicalExpression.create( - operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter) - ) - if where_filter - else new_where_filter - ) + # Change the column names from time spine column name to whatever was requested by the user. + select_columns += (SqlSelectColumn(expr=expr, column_alias=parent_instance.associated_column.column_name),) + + # Filter down to one row per granularity period requested in the group by. Any other granularities + # included here will be filtered out in later nodes so should not be included in where filter. 
+ if need_where_filter and time_spine_instance.spec in node.requested_agg_time_dimension_specs: + column_to_filter_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_alias, column_name=time_spine_instance.associated_column.column_name + ) + new_where_filter = SqlComparisonExpression.create( + left_expr=column_to_filter_expr, comparison=SqlComparison.EQUALS, right_expr=join_column_expr + ) + where_filter = ( + SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)) + if where_filter + else new_where_filter + ) return SqlDataSet( - instance_set=InstanceSet.merge([time_spine_dataset.instance_set, parent_instance_set]), + instance_set=InstanceSet.merge( + [InstanceSet(time_dimension_instances=new_time_spine_instances), parent_instance_set] + ), sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=select_columns, - from_source=time_spine_dataset.checked_sql_select_node, + from_source=time_spine_data_set.checked_sql_select_node, from_source_alias=time_spine_alias, join_descs=(join_description,), where=where_filter, diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 4da3945e3a..7bdf8e80e3 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -524,13 +524,14 @@ def make_cumulative_metric_time_range_join_description( def make_join_to_time_spine_join_description( node: JoinToTimeSpineNode, time_spine_alias: str, - agg_time_dimension_column_name: str, + time_spine_column_name: str, parent_sql_select_node: SqlSelectStatementNode, parent_alias: str, + parent_column_name: str, ) -> SqlJoinDescription: """Build join expression used to join a metric to a time spine dataset.""" left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create( - col_ref=SqlColumnReference(table_alias=time_spine_alias, 
column_name=agg_time_dimension_column_name) + col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_column_name) ) if node.offset_window: left_expr = SqlSubtractTimeIntervalExpression.create( @@ -546,7 +547,7 @@ def make_join_to_time_spine_join_description( left_expr=left_expr, comparison=SqlComparison.EQUALS, right_expr=SqlColumnReferenceExpression.create( - col_ref=SqlColumnReference(table_alias=parent_alias, column_name=agg_time_dimension_column_name) + col_ref=SqlColumnReference(table_alias=parent_alias, column_name=parent_column_name) ), ), join_type=node.join_type,