Skip to content

Commit

Permalink
Update JoinToTimeSpineNode to handle custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 19, 2024
1 parent d0d4fa9 commit 5c4a1c8
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
window_function=self.window_function,
)

def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102
def with_window_function(self, window_function: Optional[SqlWindowFunction]) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
Expand Down
58 changes: 42 additions & 16 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,47 +1447,73 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_spine_alias = self._next_unique_table_alias()

required_agg_time_dimension_specs = tuple(node.requested_agg_time_dimension_specs)
if node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs:
join_spec_was_requested = node.join_on_time_dimension_spec in node.requested_agg_time_dimension_specs
if not join_spec_was_requested:
required_agg_time_dimension_specs += (node.join_on_time_dimension_spec,)

# Build join expression.
join_column_name = self._column_association_resolver.resolve_spec(node.join_on_time_dimension_spec).column_name
parent_join_column_name = self._column_association_resolver.resolve_spec(
node.join_on_time_dimension_spec
).column_name
time_spine_jon_column_name = time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=node.join_on_time_dimension_spec.time_granularity.name, date_part=None
).associated_column.column_name
join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=join_column_name,
time_spine_column_name=time_spine_jon_column_name,
parent_column_name=parent_join_column_name,
parent_sql_select_node=parent_data_set.checked_sql_select_node,
parent_alias=parent_alias,
)

# Build combined instance set.
# Build new instances and columns.
time_spine_required_spec_set = InstanceSpecSet(time_dimension_specs=required_agg_time_dimension_specs)
parent_instance_set = parent_data_set.instance_set.transform(
output_parent_instance_set = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=time_spine_required_spec_set)
)
time_spine_instance_set = time_spine_data_set.instance_set.transform(
FilterElements(include_specs=time_spine_required_spec_set)
)
output_instance_set = InstanceSet.merge([parent_instance_set, time_spine_instance_set])
output_time_spine_instances: Tuple[TimeDimensionInstance, ...] = ()
output_time_spine_columns: Tuple[SqlSelectColumn, ...] = ()
for old_instance in time_spine_data_set.instance_set.time_dimension_instances:
new_spec = old_instance.spec.with_window_function(None)
if new_spec not in required_agg_time_dimension_specs:
continue
if old_instance.spec.window_function:
new_instance = old_instance.with_new_spec(
new_spec=new_spec, column_association_resolver=self._column_association_resolver
)
column = SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
),
column_alias=new_instance.associated_column.column_name,
)
else:
new_instance = old_instance
column = SqlSelectColumn.from_table_and_column_names(
table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
)
output_time_spine_instances += (new_instance,)
output_time_spine_columns += (column,)

# Build new simple select columns.
select_columns = create_simple_select_columns_for_instance_sets(
output_instance_set = InstanceSet.merge(
[output_parent_instance_set, InstanceSet(time_dimension_instances=output_time_spine_instances)]
)
select_columns = output_time_spine_columns + create_simple_select_columns_for_instance_sets(
self._column_association_resolver,
OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}),
OrderedDict({parent_alias: output_parent_instance_set}),
)

# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
where_filter: Optional[SqlExpressionNode] = None
need_where_filter = (
node.offset_to_grain and node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs
)
need_where_filter = node.offset_to_grain and not join_spec_was_requested

# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out before aggregation and so should not be included in where filter.
if need_where_filter:
join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=join_column_name
table_alias=time_spine_alias, column_name=parent_join_column_name
)
for requested_spec in node.requested_agg_time_dimension_specs:
column_name = self._column_association_resolver.resolve_spec(requested_spec).column_name
Expand Down
7 changes: 4 additions & 3 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,14 @@ def make_cumulative_metric_time_range_join_description(
def make_join_to_time_spine_join_description(
node: JoinToTimeSpineNode,
time_spine_alias: str,
agg_time_dimension_column_name: str,
time_spine_column_name: str,
parent_column_name: str,
parent_sql_select_node: SqlSelectStatementNode,
parent_alias: str,
) -> SqlJoinDescription:
"""Build join expression used to join a metric to a time spine dataset."""
left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name)
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_column_name)
)
if node.offset_window:
left_expr = SqlSubtractTimeIntervalExpression.create(
Expand All @@ -551,7 +552,7 @@ def make_join_to_time_spine_join_description(
left_expr=left_expr,
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=agg_time_dimension_column_name)
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=parent_column_name)
),
),
join_type=node.join_type,
Expand Down

0 comments on commit 5c4a1c8

Please sign in to comment.