Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix: querying multiple granularities with offset metrics #1054

Merged
merged 10 commits into from
Feb 29, 2024
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20240227-181223.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: Enable querying offset metric with multiple agg_time_dimensions at once. Also
fixes a bug when filtering by a different grain than the group by grain.
time: 2024-02-27T18:12:23.601203-08:00
custom:
Author: courtneyholcomb
Issue: 1052 1053
119 changes: 78 additions & 41 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,18 @@ def _make_time_spine_data_set(
This is useful in computing cumulative metrics. This will need to be updated to support granularities finer than a
day.
"""
time_spine_instance = (
TimeDimensionInstance(
defined_from=agg_time_dimension_instance.defined_from,
associated_columns=(
ColumnAssociation(
column_name=agg_time_dimension_column_name,
single_column_correlation_key=SingleColumnCorrelationKey(),
),
time_spine_instance = TimeDimensionInstance(
defined_from=agg_time_dimension_instance.defined_from,
associated_columns=(
ColumnAssociation(
column_name=agg_time_dimension_column_name,
single_column_correlation_key=SingleColumnCorrelationKey(),
),
spec=agg_time_dimension_instance.spec,
),
spec=agg_time_dimension_instance.spec,
)
time_spine_instance_set = InstanceSet(time_dimension_instances=time_spine_instance)

time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,))
time_spine_table_alias = self._next_unique_table_alias()

# If the requested granularity is the same as the granularity of the spine, do a direct select.
Expand Down Expand Up @@ -1298,8 +1297,18 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
parent_alias=parent_alias,
)

# Select all instances from the parent data set, EXCEPT the requested agg_time_dimension.
# The agg_time_dimension will be selected from the time spine data set.
# Select all instances from the parent data set, EXCEPT agg_time_dimensions.
# The agg_time_dimensions will be selected from the time spine data set.
time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
):
time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
else:
time_dimensions_to_select_from_parent += (time_dimension_instance,)
parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
Expand All @@ -1324,46 +1333,74 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
len(time_spine_dataset.instance_set.time_dimension_instances) == 1
and len(time_spine_dataset.sql_select_node.select_columns) == 1
), "Time spine dataset not configured properly. Expected exactly one column."
time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
original_time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression(
SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_dim_instance.spec.qualified_name)
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
time_spine_select_columns = []
time_spine_dim_instances = []
where: Optional[SqlExpressionNode] = None
for requested_time_dimension_spec in node.requested_agg_time_dimension_specs:
# Apply granularity to time spine column select expression.
if requested_time_dimension_spec.time_granularity == time_spine_dim_instance.spec.time_granularity:
select_expr: SqlExpressionNode = time_spine_column_select_expr
else:
select_expr = SqlDateTruncExpression(
time_granularity=requested_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
where_filter: Optional[SqlExpressionNode] = None

# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
need_where_filter = (
node.offset_to_grain
and original_time_spine_dim_instance.spec not in node.requested_agg_time_dimension_specs
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = time_dimension_instance.spec

# TODO: this will break when we start supporting smaller grain than DAY unless the time spine table is
# updated to use the smallest available grain.
if (
time_dimension_spec.time_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.to_int()
):
raise RuntimeError(
f"Can't join to time spine for a time dimension with a smaller granularity than that of the time "
f"spine column. Got {time_dimension_spec.time_granularity} for time dimension, "
f"{original_time_spine_dim_instance.spec.time_granularity} for time spine."
)
if node.offset_to_grain:
# Filter down to one row per granularity period
new_filter = SqlComparisonExpression(
left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr
)
if not where:
where = new_filter
else:
where = SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where, new_filter))

# Apply grain to time spine select expression, unless grain already matches original time spine column.
select_expr: SqlExpressionNode = (
time_spine_column_select_expr
if time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity
else SqlDateTruncExpression(
time_granularity=time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
)
)
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
if need_where_filter and time_dimension_spec in node.requested_agg_time_dimension_specs:
new_where_filter = SqlComparisonExpression(
left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr
)
where_filter = (
SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter))
if where_filter
else new_where_filter
)

# Apply date_part to time spine column select expression.
if requested_time_dimension_spec.date_part:
select_expr = SqlExtractExpression(date_part=requested_time_dimension_spec.date_part, arg=select_expr)
if time_dimension_spec.date_part:
select_expr = SqlExtractExpression(date_part=time_dimension_spec.date_part, arg=select_expr)
time_dim_spec = TimeDimensionSpec(
element_name=time_spine_dim_instance.spec.element_name,
entity_links=time_spine_dim_instance.spec.entity_links,
time_granularity=requested_time_dimension_spec.time_granularity,
date_part=requested_time_dimension_spec.date_part,
aggregation_state=time_spine_dim_instance.spec.aggregation_state,
element_name=original_time_spine_dim_instance.spec.element_name,
entity_links=original_time_spine_dim_instance.spec.entity_links,
time_granularity=time_dimension_spec.time_granularity,
date_part=time_dimension_spec.date_part,
aggregation_state=original_time_spine_dim_instance.spec.aggregation_state,
)
time_spine_dim_instance = TimeDimensionInstance(
defined_from=time_spine_dim_instance.defined_from,
defined_from=original_time_spine_dim_instance.defined_from,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yikes, this was inside the loop!

associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
spec=time_dim_spec,
)
Expand All @@ -1383,7 +1420,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
joins_descs=(join_description,),
group_bys=(),
order_bys=(),
where=where,
where=where_filter,
),
)

Expand Down
62 changes: 62 additions & 0 deletions metricflow/test/dataflow/builder/test_dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,3 +1203,65 @@ def test_join_to_time_spine_with_filters(
mf_test_session_state=mf_test_session_state,
dag_graph=dataflow_plan,
)


def test_offset_window_metric_filter_and_query_have_different_granularities(
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    dataflow_plan_builder: DataflowPlanBuilder,
    query_parser: MetricFlowQueryParser,
    create_source_tables: bool,
) -> None:
    """Test a query where an offset window metric is queried with one granularity and filtered by a different one."""
    # Filter at day grain while grouping at month grain to exercise the mismatched-granularity path.
    day_grain_filter = PydanticWhereFilter(
        where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'")
    )
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("booking_fees_last_week_per_booker_this_week",),
        group_by_names=("metric_time__month",),
        where_constraint=day_grain_filter,
    )

    # Build the plan and snapshot its textual structure for regression comparison.
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)
    assert_plan_snapshot_text_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        plan=dataflow_plan,
        plan_snapshot_text=dataflow_plan.text_structure(),
    )
    display_graph_if_requested(
        request=request,
        mf_test_session_state=mf_test_session_state,
        dag_graph=dataflow_plan,
    )


def test_offset_to_grain_metric_filter_and_query_have_different_granularities(
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    dataflow_plan_builder: DataflowPlanBuilder,
    query_parser: MetricFlowQueryParser,
    create_source_tables: bool,
) -> None:
    """Test a query where an offset to grain metric is queried with one granularity and filtered by a different one."""
    # The where filter constrains metric_time at day grain; the group by uses month grain.
    day_grain_filter = PydanticWhereFilter(
        where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'")
    )
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("bookings_at_start_of_month",),
        group_by_names=("metric_time__month",),
        where_constraint=day_grain_filter,
    )

    # Build the plan and snapshot its textual structure for regression comparison.
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)
    assert_plan_snapshot_text_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        plan=dataflow_plan,
        plan_snapshot_text=dataflow_plan.text_structure(),
    )
    display_graph_if_requested(
        request=request,
        mf_test_session_state=mf_test_session_state,
        dag_graph=dataflow_plan,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,19 @@ metric:
- name: bookings_offset_once
offset_window: 2 days
---
metric:
name: bookings_at_start_of_month
description: |
Derived metric with offset to grain - single input metric.
Not a particularly useful metric but it allows us to isolate behavior for offset to grain.
type: derived
type_params:
expr: bookings_start_of_month
metrics:
- name: bookings
offset_to_grain: month
alias: bookings_start_of_month
---
metric:
name: booking_fees_since_start_of_month
description: nested derived metric with offset and multiple input metrics
Expand Down
60 changes: 60 additions & 0 deletions metricflow/test/integration/query_output/test_offset_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import pytest
from _pytest.fixtures import FixtureRequest

from metricflow.engine.metricflow_engine import MetricFlowQueryRequest
from metricflow.protocols.sql_client import SqlClient
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.integration.conftest import IntegrationTestHelpers
from metricflow.test.snapshot_utils import assert_str_snapshot_equal


@pytest.mark.sql_engine_snapshot
def test_offset_to_grain_with_single_granularity(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    """Query an offset-to-grain metric at a single granularity and snapshot the query output."""
    # NOTE: the original scraped text had GitHub review-comment markup interleaved inside this
    # call expression ("Copy link / Contributor / ..."), which made the function syntactically
    # invalid; the junk has been removed and the call restored.
    query_result = it_helpers.mf_engine.query(
        MetricFlowQueryRequest.create_with_random_request_id(
            metric_names=["bookings_at_start_of_month"],
            group_by_names=["metric_time__day"],
            order_by_names=["metric_time__day"],
        )
    )
    assert query_result.result_df is not None, "Unexpected empty result."

    assert_str_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        snapshot_id="query_output",
        snapshot_str=query_result.result_df.to_string(),
        sql_engine=sql_client.sql_engine_type,
    )


@pytest.mark.sql_engine_snapshot
def test_offset_to_grain_with_multiple_granularities(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    # Query the offset-to-grain metric across day, month, and year grains at once.
    mf_request = MetricFlowQueryRequest.create_with_random_request_id(
        metric_names=["bookings_at_start_of_month"],
        group_by_names=["metric_time__day", "metric_time__month", "metric_time__year"],
        order_by_names=["metric_time__day", "metric_time__month", "metric_time__year"],
    )
    query_result = it_helpers.mf_engine.query(mf_request)
    assert query_result.result_df is not None, "Unexpected empty result."

    # Snapshot the rendered result frame for the active SQL engine.
    assert_str_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        snapshot_id="query_output",
        snapshot_str=query_result.result_df.to_string(),
        sql_engine=sql_client.sql_engine_type,
    )
Loading
Loading