Join to Time Spine & Fill Nulls #832

Merged · 13 commits · Nov 2, 2023
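This PR threads the new fill_nulls_with measure parameter through SQL generation and generalizes the time spine join. In visit_compute_metrics_node, simple and cumulative metric columns are now built via a helper that wraps the column reference in COALESCE when the input measure requests null-filling. visit_join_to_time_spine_node now re-applies each requested granularity and date_part to the time spine column rather than assuming the spine's default granularity, and new dataflow-to-SQL tests cover fill-nulls behavior across metric types, granularities, and groupings.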
101 changes: 56 additions & 45 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -5,7 +5,7 @@
 from typing import List, Optional, Sequence, Union

 from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
-from dbt_semantic_interfaces.protocols.metric import MetricType
+from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
 from dbt_semantic_interfaces.references import MetricModelReference
 from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType

@@ -81,6 +81,7 @@
     SqlQueryOptimizerConfiguration,
 )
 from metricflow.sql.sql_exprs import (
+    SqlAggregateFunctionExpression,
     SqlBetweenExpression,
     SqlColumnReference,
     SqlColumnReferenceExpression,
@@ -89,6 +90,7 @@
     SqlDateTruncExpression,
     SqlExpressionNode,
     SqlExtractExpression,
+    SqlFunction,
     SqlFunctionExpression,
     SqlLogicalExpression,
     SqlLogicalOperator,
@@ -633,6 +635,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
             metric = self._metric_lookup.get_metric(metric_spec.as_reference)

             metric_expr: Optional[SqlExpressionNode] = None
+            input_measure: Optional[MetricInputMeasure] = None
             if metric.type is MetricType.RATIO:
                 numerator = metric.type_params.numerator
                 denominator = metric.type_params.denominator
@@ -664,33 +667,26 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
                 if len(metric.input_measures) > 0:
                     assert (
                         len(metric.input_measures) == 1
-                    ), "Measure proxy metrics should always source from exactly 1 measure."
+                    ), "Simple metrics should always source from exactly 1 measure."
+                    input_measure = metric.input_measures[0]
                     expr = self._column_association_resolver.resolve_spec(
-                        MeasureSpec(
-                            element_name=metric.input_measures[0].post_aggregation_measure_reference.element_name
-                        )
+                        MeasureSpec(element_name=input_measure.post_aggregation_measure_reference.element_name)
                     ).column_name
                 else:
                     expr = metric.name
-                # Use a column reference to improve query optimization.
-                metric_expr = SqlColumnReferenceExpression(
-                    SqlColumnReference(
-                        table_alias=from_data_set_alias,
-                        column_name=expr,
-                    )
+                metric_expr = self.__make_col_reference_or_coalesce_expr(
+                    column_name=expr, input_measure=input_measure, from_data_set_alias=from_data_set_alias
                 )
             elif metric.type is MetricType.CUMULATIVE:
                 assert (
                     len(metric.measure_references) == 1
                 ), "Cumulative metrics should always source from exactly 1 measure."
+                input_measure = metric.input_measures[0]
                 expr = self._column_association_resolver.resolve_spec(
-                    MeasureSpec(element_name=metric.input_measures[0].post_aggregation_measure_reference.element_name)
+                    MeasureSpec(element_name=input_measure.post_aggregation_measure_reference.element_name)
                 ).column_name
-                metric_expr = SqlColumnReferenceExpression(
-                    SqlColumnReference(
-                        table_alias=from_data_set_alias,
-                        column_name=expr,
-                    )
+                metric_expr = self.__make_col_reference_or_coalesce_expr(
+                    column_name=expr, input_measure=input_measure, from_data_set_alias=from_data_set_alias
                 )
             elif metric.type is MetricType.DERIVED:
                 assert metric.type_params.expr
@@ -734,6 +730,21 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
             ),
         )

+    def __make_col_reference_or_coalesce_expr(
+        self, column_name: str, input_measure: Optional[MetricInputMeasure], from_data_set_alias: str
+    ) -> SqlExpressionNode:
+        # Use a column reference to improve query optimization.
+        metric_expr: SqlExpressionNode = SqlColumnReferenceExpression(
+            SqlColumnReference(table_alias=from_data_set_alias, column_name=column_name)
+        )
+        # Coalesce nulls to requested integer value, if requested.
+        if input_measure and input_measure.fill_nulls_with is not None:
+            metric_expr = SqlAggregateFunctionExpression(
+                sql_function=SqlFunction.COALESCE,
+                sql_function_args=[metric_expr, SqlStringExpression(str(input_measure.fill_nulls_with))],
+            )
+        return metric_expr
+
     def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet:  # noqa: D
         from_data_set: SqlDataSet = node.parent_node.accept(self)
         output_instance_set = from_data_set.instance_set
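A minimal sketch of what the new helper builds, assuming the metricflow package is importable; the table alias and column name below are hypothetical. (COALESCE is not an aggregate function; the diff reuses SqlAggregateFunctionExpression as a generic function-call node here.)

```python
from metricflow.sql.sql_exprs import (
    SqlAggregateFunctionExpression,
    SqlColumnReference,
    SqlColumnReferenceExpression,
    SqlFunction,
    SqlStringExpression,
)

# Plain column reference: renders roughly as subq_7.bookings.
metric_expr = SqlColumnReferenceExpression(
    SqlColumnReference(table_alias="subq_7", column_name="bookings")
)
# With fill_nulls_with=0 on the input measure, the helper wraps the
# reference so the rendered SQL reads roughly COALESCE(subq_7.bookings, 0).
metric_expr = SqlAggregateFunctionExpression(
    sql_function=SqlFunction.COALESCE,
    sql_function_args=[metric_expr, SqlStringExpression("0")],
)
```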
@@ -1312,7 +1323,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
                 metric_time_dimension_instance = instance
         assert (
             metric_time_dimension_instance
-        ), "Can't query offset metric without a time dimension. Validations should have prevented this."
+        ), "Can't join to time spine without metric time. Validations should have prevented this."
         metric_time_dimension_column_name = self.column_association_resolver.resolve_spec(
             metric_time_dimension_instance.spec
         ).column_name
@@ -1346,33 +1357,33 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
             metric_instances=parent_data_set.instance_set.metric_instances,
             metadata_instances=parent_data_set.instance_set.metadata_instances,
         )
-        non_metric_time_select_columns = create_select_columns_for_instance_sets(
+        parent_select_columns = create_select_columns_for_instance_sets(
             self._column_association_resolver, OrderedDict({parent_alias: non_metric_time_parent_instance_set})
         )

-        # Use metric_time column from time spine.
+        # Use time instance from time spine to replace metric_time instances.
         assert (
             len(time_spine_dataset.instance_set.time_dimension_instances) == 1
             and len(time_spine_dataset.sql_select_node.select_columns) == 1
         ), "Time spine dataset not configured properly. Expected exactly one column."
-        time_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
+        time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
         time_spine_column_select_expr: Union[
             SqlColumnReferenceExpression, SqlDateTruncExpression
         ] = SqlColumnReferenceExpression(
-            SqlColumnReference(table_alias=time_spine_alias, column_name=time_dim_instance.spec.qualified_name)
+            SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_dim_instance.spec.qualified_name)
         )

-        # Add requested granularities (skip for default granularity) and date_parts.
-        metric_time_select_columns = []
-        metric_time_dimension_instances = []
+        # Add requested granularities (if different from time_spine) and date_parts to time spine column.
+        time_spine_select_columns = []
+        time_spine_dim_instances = []
         where: Optional[SqlExpressionNode] = None
-        for metric_time_dimension_spec in node.requested_metric_time_dimension_specs:
-            # Apply granularity to SQL.
-            if metric_time_dimension_spec.time_granularity == self._time_spine_source.time_column_granularity:
+        for requested_time_dimension_spec in node.requested_metric_time_dimension_specs:
+            # Apply granularity to time spine column select expression.
+            if requested_time_dimension_spec.time_granularity == time_spine_dim_instance.spec.time_granularity:
                 select_expr: SqlExpressionNode = time_spine_column_select_expr
             else:
                 select_expr = SqlDateTruncExpression(
-                    time_granularity=metric_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
+                    time_granularity=requested_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
                 )
             if node.offset_to_grain:
                 # Filter down to one row per granularity period
@@ -1383,32 +1394,32 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
                     where = new_filter
                 else:
                     where = SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where, new_filter))
-            # Apply date_part to SQL.
-            if metric_time_dimension_spec.date_part:
-                select_expr = SqlExtractExpression(date_part=metric_time_dimension_spec.date_part, arg=select_expr)
+            # Apply date_part to time spine column select expression.
+            if requested_time_dimension_spec.date_part:
+                select_expr = SqlExtractExpression(date_part=requested_time_dimension_spec.date_part, arg=select_expr)
             time_dim_spec = TimeDimensionSpec(
-                element_name=time_dim_instance.spec.element_name,
-                entity_links=time_dim_instance.spec.entity_links,
-                time_granularity=metric_time_dimension_spec.time_granularity,
-                date_part=metric_time_dimension_spec.date_part,
-                aggregation_state=time_dim_instance.spec.aggregation_state,
+                element_name=time_spine_dim_instance.spec.element_name,
+                entity_links=time_spine_dim_instance.spec.entity_links,
+                time_granularity=requested_time_dimension_spec.time_granularity,
+                date_part=requested_time_dimension_spec.date_part,
+                aggregation_state=time_spine_dim_instance.spec.aggregation_state,
             )
-            time_dim_instance = TimeDimensionInstance(
-                defined_from=time_dim_instance.defined_from,
+            time_spine_dim_instance = TimeDimensionInstance(
+                defined_from=time_spine_dim_instance.defined_from,
                 associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
                 spec=time_dim_spec,
             )
-            metric_time_dimension_instances.append(time_dim_instance)
-            metric_time_select_columns.append(
-                SqlSelectColumn(expr=select_expr, column_alias=time_dim_instance.associated_column.column_name)
+            time_spine_dim_instances.append(time_spine_dim_instance)
+            time_spine_select_columns.append(
+                SqlSelectColumn(expr=select_expr, column_alias=time_spine_dim_instance.associated_column.column_name)
             )
-        metric_time_instance_set = InstanceSet(time_dimension_instances=tuple(metric_time_dimension_instances))
+        time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_dim_instances))

         return SqlDataSet(
-            instance_set=InstanceSet.merge([metric_time_instance_set, non_metric_time_parent_instance_set]),
+            instance_set=InstanceSet.merge([time_spine_instance_set, non_metric_time_parent_instance_set]),
             sql_select_node=SqlSelectStatementNode(
                 description=node.description,
-                select_columns=tuple(metric_time_select_columns) + non_metric_time_select_columns,
+                select_columns=tuple(time_spine_select_columns) + parent_select_columns,
                 from_source=time_spine_dataset.sql_select_node,
                 from_source_alias=time_spine_alias,
                 joins_descs=(join_description,),
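A minimal sketch of the granularity handling above, assuming metricflow and dbt_semantic_interfaces are importable; the alias and column name are hypothetical. When a requested granularity differs from the spine column's own, the column is wrapped in DATE_TRUNC; a requested date_part additionally wraps the expression in EXTRACT, and offset_to_grain ORs together one where-filter per requested spec to keep one row per granularity period.

```python
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.sql.sql_exprs import (
    SqlColumnReference,
    SqlColumnReferenceExpression,
    SqlDateTruncExpression,
)

# Day-grained time spine column: renders roughly as time_spine_src.ds.
time_spine_column_select_expr = SqlColumnReferenceExpression(
    SqlColumnReference(table_alias="time_spine_src", column_name="ds")
)
# A MONTH-grained request wraps it: DATE_TRUNC('month', time_spine_src.ds).
select_expr = SqlDateTruncExpression(
    time_granularity=TimeGranularity.MONTH, arg=time_spine_column_select_expr
)
```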
194 changes: 194 additions & 0 deletions metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py
@@ -1998,3 +1998,197 @@ def test_offset_window_with_date_part(  # noqa: D
         sql_client=sql_client,
         node=dataflow_plan.sink_output_nodes[0].parent_node,
     )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_simple_fill_nulls_with_0_metric_time(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_fill_0"),),
+            time_dimension_specs=(DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.DAY),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
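For orientation, the day-grained fill-nulls plan above should render SQL of roughly the following shape. This is an illustrative sketch, not a repository snapshot, and it assumes bookings_fill_0 is a simple metric whose input measure sets fill_nulls_with: 0 and joins to the time spine.

```python
# Illustrative shape of the rendered SQL (not an actual snapshot).
expected_sql_shape = """
SELECT
  subq.metric_time__day
  , COALESCE(subq.bookings, 0) AS bookings_fill_0
FROM (
  -- bookings aggregated by day, LEFT OUTER JOINed onto the time spine,
  -- so days with no bookings surface as rows with NULL bookings
  ...
) subq
"""
```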


+@pytest.mark.sql_engine_snapshot
+def test_simple_fill_nulls_with_0_month(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_fill_0"),),
+            time_dimension_specs=(DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.MONTH),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_simple_fill_0_with_non_metric_time(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_fill_0"),),
+            time_dimension_specs=(
+                TimeDimensionSpec(element_name="paid_at", entity_links=(EntityReference("booking"),)),
+            ),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_simple_fill_0_with_categorical_dimension(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_fill_0"),),
+            dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_simple_join_to_time_spine(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_join_to_time_spine"),),
+            time_dimension_specs=(DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.DAY),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_simple_fill_nulls_without_time_spine(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_fill_0_without_time_spine"),),
+            time_dimension_specs=(DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.DAY),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_cumulative_fill_nulls(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="every_two_days_bookers_fill_0"),),
+            time_dimension_specs=(DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.DAY),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_derived_fill_nulls_for_one_input_metric(  # noqa: D
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+) -> None:
+    dataflow_plan = dataflow_plan_builder.build_plan(
+        MetricFlowQuerySpec(
+            metric_specs=(MetricSpec(element_name="bookings_growth_2_weeks_fill_0_for_non_offset"),),
+            time_dimension_specs=(DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.DAY),),
+        )
+    )
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
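Judging by the fixture name, bookings_growth_2_weeks_fill_0_for_non_offset is assumed to be a derived metric in which only the non-offset input metric sets fill_nulls_with: 0, so nulls are filled for that input while the offset input passes them through.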