From 5c4a1c84cbc01d524a80667202c93aea386c6a4a Mon Sep 17 00:00:00 2001
From: Courtney Holcomb
Date: Wed, 18 Dec 2024 22:31:15 -0800
Subject: [PATCH] Update JoinToTimeSpineNode to handle custom offset windows

---
 .../specs/time_dimension_spec.py              |  2 +-
 metricflow/plan_conversion/dataflow_to_sql.py | 58 ++++++++++++++-----
 .../plan_conversion/sql_join_builder.py       |  7 ++-
 3 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py
index dec834adc..62211493d 100644
--- a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py
+++ b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py
@@ -195,7 +195,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
             window_function=self.window_function,
         )
 
-    def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec:  # noqa: D102
+    def with_window_function(self, window_function: Optional[SqlWindowFunction]) -> TimeDimensionSpec:  # noqa: D102
         return TimeDimensionSpec(
             element_name=self.element_name,
             entity_links=self.entity_links,
diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py
index b03d76e4d..c90495863 100644
--- a/metricflow/plan_conversion/dataflow_to_sql.py
+++ b/metricflow/plan_conversion/dataflow_to_sql.py
@@ -1447,47 +1447,73 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
         time_spine_alias = self._next_unique_table_alias()
 
         required_agg_time_dimension_specs = tuple(node.requested_agg_time_dimension_specs)
-        if node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs:
+        join_spec_was_requested = node.join_on_time_dimension_spec in node.requested_agg_time_dimension_specs
+        if not join_spec_was_requested:
             required_agg_time_dimension_specs += (node.join_on_time_dimension_spec,)
 
         # Build join expression.
-        join_column_name = self._column_association_resolver.resolve_spec(node.join_on_time_dimension_spec).column_name
+        parent_join_column_name = self._column_association_resolver.resolve_spec(
+            node.join_on_time_dimension_spec
+        ).column_name
+        time_spine_join_column_name = time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
+            time_granularity_name=node.join_on_time_dimension_spec.time_granularity.name, date_part=None
+        ).associated_column.column_name
         join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
             node=node,
             time_spine_alias=time_spine_alias,
-            agg_time_dimension_column_name=join_column_name,
+            time_spine_column_name=time_spine_join_column_name,
+            parent_column_name=parent_join_column_name,
             parent_sql_select_node=parent_data_set.checked_sql_select_node,
             parent_alias=parent_alias,
         )
 
-        # Build combined instance set.
+        # Build new instances and columns.
         time_spine_required_spec_set = InstanceSpecSet(time_dimension_specs=required_agg_time_dimension_specs)
-        parent_instance_set = parent_data_set.instance_set.transform(
+        output_parent_instance_set = parent_data_set.instance_set.transform(
             FilterElements(exclude_specs=time_spine_required_spec_set)
         )
-        time_spine_instance_set = time_spine_data_set.instance_set.transform(
-            FilterElements(include_specs=time_spine_required_spec_set)
-        )
-        output_instance_set = InstanceSet.merge([parent_instance_set, time_spine_instance_set])
+        output_time_spine_instances: Tuple[TimeDimensionInstance, ...] = ()
+        output_time_spine_columns: Tuple[SqlSelectColumn, ...] = ()
+        for old_instance in time_spine_data_set.instance_set.time_dimension_instances:
+            new_spec = old_instance.spec.with_window_function(None)
+            if new_spec not in required_agg_time_dimension_specs:
+                continue
+            if old_instance.spec.window_function:
+                new_instance = old_instance.with_new_spec(
+                    new_spec=new_spec, column_association_resolver=self._column_association_resolver
+                )
+                column = SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression.from_table_and_column_names(
+                        table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
+                    ),
+                    column_alias=new_instance.associated_column.column_name,
+                )
+            else:
+                new_instance = old_instance
+                column = SqlSelectColumn.from_table_and_column_names(
+                    table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
+                )
+            output_time_spine_instances += (new_instance,)
+            output_time_spine_columns += (column,)
 
-        # Build new simple select columns.
-        select_columns = create_simple_select_columns_for_instance_sets(
+        output_instance_set = InstanceSet.merge(
+            [output_parent_instance_set, InstanceSet(time_dimension_instances=output_time_spine_instances)]
+        )
+        select_columns = output_time_spine_columns + create_simple_select_columns_for_instance_sets(
             self._column_association_resolver,
-            OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}),
+            OrderedDict({parent_alias: output_parent_instance_set}),
         )
 
         # If offset_to_grain is used, will need to filter down to rows that match selected granularities.
         # Does not apply if one of the granularities selected matches the time spine column granularity.
         where_filter: Optional[SqlExpressionNode] = None
-        need_where_filter = (
-            node.offset_to_grain and node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs
-        )
+        need_where_filter = node.offset_to_grain and not join_spec_was_requested
 
         # Filter down to one row per granularity period requested in the group by. Any other granularities
         # included here will be filtered out before aggregation and so should not be included in where filter.
         if need_where_filter:
             join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
-                table_alias=time_spine_alias, column_name=join_column_name
+                table_alias=time_spine_alias, column_name=parent_join_column_name
             )
             for requested_spec in node.requested_agg_time_dimension_specs:
                 column_name = self._column_association_resolver.resolve_spec(requested_spec).column_name
diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py
index f80cdf228..9f2e37743 100644
--- a/metricflow/plan_conversion/sql_join_builder.py
+++ b/metricflow/plan_conversion/sql_join_builder.py
@@ -527,13 +527,14 @@ def make_cumulative_metric_time_range_join_description(
     def make_join_to_time_spine_join_description(
         node: JoinToTimeSpineNode,
         time_spine_alias: str,
-        agg_time_dimension_column_name: str,
+        time_spine_column_name: str,
+        parent_column_name: str,
         parent_sql_select_node: SqlSelectStatementNode,
         parent_alias: str,
     ) -> SqlJoinDescription:
         """Build join expression used to join a metric to a time spine dataset."""
         left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create(
-            col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name)
+            col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_column_name)
         )
         if node.offset_window:
             left_expr = SqlSubtractTimeIntervalExpression.create(
@@ -551,7 +552,7 @@ def make_join_to_time_spine_join_description(
                 left_expr=left_expr,
                 comparison=SqlComparison.EQUALS,
                 right_expr=SqlColumnReferenceExpression.create(
-                    col_ref=SqlColumnReference(table_alias=parent_alias, column_name=agg_time_dimension_column_name)
+                    col_ref=SqlColumnReference(table_alias=parent_alias, column_name=parent_column_name)
                 ),
             ),
             join_type=node.join_type,
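
To illustrate the effect of splitting the old agg_time_dimension_column_name parameter into time_spine_column_name and parent_column_name, here is a minimal standalone sketch of the ON condition the join builder now expresses: the left side is the time spine column at the grain of join_on_time_dimension_spec (shifted back when an offset window is requested), compared against the parent data set's own column. This is not MetricFlow code; the helper name, table aliases, column names, and the DATE_SUB rendering are assumptions made only for the example.

from typing import Optional


def sketch_on_condition(
    time_spine_alias: str,
    time_spine_column_name: str,
    parent_alias: str,
    parent_column_name: str,
    offset_window_sql: Optional[str] = None,
) -> str:
    # Left side: the time spine column at the grain of join_on_time_dimension_spec,
    # shifted back by the offset window when one is given (the role played by
    # SqlSubtractTimeIntervalExpression in the patch).
    left = f"{time_spine_alias}.{time_spine_column_name}"
    if offset_window_sql:
        left = f"DATE_SUB({left}, INTERVAL {offset_window_sql})"
    # Right side: the parent data set's column for the same spec, which may carry a
    # different name than the time spine column now that the two are resolved separately.
    return f"{left} = {parent_alias}.{parent_column_name}"


# Hypothetical example: an offset window applied to a month-grain time spine column,
# joined back to the parent's metric_time__day column.
print(sketch_on_condition("time_spine_src_1", "ds__month", "subq_2", "metric_time__day", "1 MONTH"))
# -> DATE_SUB(time_spine_src_1.ds__month, INTERVAL 1 MONTH) = subq_2.metric_time__day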