Skip to content

Commit

Permalink
Support for multiple queried agg time dimensions in JoinOverTimeRange SQL rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jun 12, 2024
1 parent 18d3e71 commit f6fb1b6
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,16 @@ def _next_unique_table_alias(self) -> str:

def _make_time_spine_data_set(
self,
agg_time_dimension_instance: TimeDimensionInstance,
agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...],
time_spine_source: TimeSpineSource,
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> SqlDataSet:
"""Make a time spine data set, which contains all date/time values like '2020-01-01', '2020-01-02'...
Returns a data set with a column for the agg_time_dimension requested.
Returns a dataset with a column selected for each agg_time_dimension requested.
Column alias will use 'metric_time' or the agg_time_dimension name depending on which the user requested.
"""
time_spine_instance_set = InstanceSet(time_dimension_instances=(agg_time_dimension_instance,))
time_spine_instance_set = InstanceSet(time_dimension_instances=agg_time_dimension_instances)
time_spine_table_alias = self._next_unique_table_alias()

column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
Expand All @@ -240,22 +240,23 @@ def _make_time_spine_data_set(

select_columns: Tuple[SqlSelectColumn, ...] = ()
apply_group_by = False
column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name
# If the requested granularity matches that of the time spine, do a direct select.
# TODO: also handle date part.
if agg_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),)
# Otherwise, apply a DATE_TRUNC() and aggregate via group_by.
else:
select_columns += (
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=agg_time_dimension_instance.spec.time_granularity, arg=column_expr
for agg_time_dimension_instance in agg_time_dimension_instances:
column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
# TODO: also handle date part.
if agg_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),)
# If any columns have a different granularity, apply a DATE_TRUNC() and aggregate via group_by.
else:
select_columns += (
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=agg_time_dimension_instance.spec.time_granularity, arg=column_expr
),
column_alias=column_alias,
),
column_alias=column_alias,
),
)
apply_group_by = True
)
apply_group_by = True

return SqlDataSet(
instance_set=time_spine_instance_set,
Expand Down Expand Up @@ -290,20 +291,28 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
input_data_set = node.parent_node.accept(self)
input_data_set_alias = self._next_unique_table_alias()

agg_time_dimension_instance: Optional[TimeDimensionInstance] = None
# Find requested agg_time_dimensions in parent instance set.
# For now, will use instance with smallest granularity in time spine join.
# TODO: use metric's default_grain once that property is available.
agg_time_dimension_instance_for_join: Optional[TimeDimensionInstance] = None
requested_agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
for instance in input_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec_for_join:
agg_time_dimension_instance = instance
break
if instance.spec in node.queried_agg_time_dimension_specs:
requested_agg_time_dimension_instances += (instance,)
if not agg_time_dimension_instance_for_join or (
instance.spec.time_granularity.to_int()
< agg_time_dimension_instance_for_join.spec.time_granularity.to_int()
):
agg_time_dimension_instance_for_join = instance
assert (
agg_time_dimension_instance
agg_time_dimension_instance_for_join
), "Specified metric time spec not found in parent data set. This should have been caught by validations."

time_spine_data_set_alias = self._next_unique_table_alias()

# Assemble time_spine dataset with agg_time_dimension_instance selected.
# Assemble time_spine dataset with agg_time_dimension_instance_for_join selected.
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instance=agg_time_dimension_instance,
agg_time_dimension_instances=requested_agg_time_dimension_instances,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
Expand All @@ -315,21 +324,22 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
data_set=input_data_set,
alias=input_data_set_alias,
_metric_time_column_name=input_data_set.column_association_for_time_dimension(
agg_time_dimension_instance.spec
agg_time_dimension_instance_for_join.spec
).column_name,
),
time_spine_data_set=AnnotatedSqlDataSet(
data_set=time_spine_data_set,
alias=time_spine_data_set_alias,
_metric_time_column_name=time_spine_data_set.column_association_for_time_dimension(
agg_time_dimension_instance.spec
agg_time_dimension_instance_for_join.spec
).column_name,
),
)

# Remove agg_time_dimension from input data set. It will be replaced with the time spine instance.
# Remove instances of agg_time_dimension from input data set. They'll be replaced with time spine instances.
agg_time_dimension_specs = tuple(dim.spec for dim in requested_agg_time_dimension_instances)
modified_input_instance_set = input_data_set.instance_set.transform(
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=(agg_time_dimension_instance.spec,)))
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs))
)
table_alias_to_instance_set[input_data_set_alias] = modified_input_instance_set

Expand Down Expand Up @@ -1042,9 +1052,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
Expand Down Expand Up @@ -1222,7 +1232,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
# Build time spine data set using the requested agg_time_dimension name.
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instance=agg_time_dimension_instance_for_join,
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
Expand Down Expand Up @@ -1276,11 +1286,11 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
and len(time_spine_dataset.checked_sql_select_node.select_columns) == 1
), "Time spine dataset not configured properly. Expected exactly one column."
original_time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
time_spine_column_select_expr: Union[SqlColumnReferenceExpression, SqlDateTruncExpression] = (
SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
)
)

Expand Down

0 comments on commit f6fb1b6

Please sign in to comment.