WIP - fix bug in filter for offset metrics with diff granularities
courtneyholcomb committed Feb 23, 2024
1 parent 42656f3 commit 6124a27
Showing 32 changed files with 1,326 additions and 17 deletions.
28 changes: 21 additions & 7 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -1252,6 +1252,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
         parent_alias = self._next_unique_table_alias()
 
         if node.use_custom_agg_time_dimension:
+            # TODO: do we need the `requested_agg_time_dimension_specs` property anymore?
             agg_time_dimension = node.requested_agg_time_dimension_specs[0]
             agg_time_element_name = agg_time_dimension.element_name
             agg_time_entity_links: Tuple[EntityReference, ...] = agg_time_dimension.entity_links
@@ -1300,6 +1301,16 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
 
         # Select all instances from the parent data set, EXCEPT the requested agg_time_dimension.
         # The agg_time_dimension will be selected from the time spine data set.
+        time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
+        time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
+        for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
+            if (
+                time_dimension_instance.spec.element_name == agg_time_element_name
+                and time_dimension_instance.spec.entity_links == agg_time_entity_links
+            ):
+                time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
+            else:
+                time_dimensions_to_select_from_parent += (time_dimension_instance,)
         parent_instance_set = InstanceSet(
             measure_instances=parent_data_set.instance_set.measure_instances,
             dimension_instances=parent_data_set.instance_set.dimension_instances,
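The added loop partitions the parent data set's time dimension instances by matching `(element_name, entity_links)` against the aggregation time dimension: every instance of the agg time dimension (at any granularity, including one referenced only by a where filter) is re-selected from the time spine, and everything else passes through from the parent. A minimal standalone sketch of that partition, using simplified stand-ins for MetricFlow's `TimeDimensionSpec`/`TimeDimensionInstance`:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Spec:
    """Simplified stand-in for TimeDimensionSpec."""

    element_name: str
    entity_links: Tuple[str, ...]
    time_granularity: str


@dataclass(frozen=True)
class Instance:
    """Simplified stand-in for TimeDimensionInstance."""

    spec: Spec


def partition(
    instances: Tuple[Instance, ...],
    agg_time_element_name: str,
    agg_time_entity_links: Tuple[str, ...],
) -> Tuple[Tuple[Instance, ...], Tuple[Instance, ...]]:
    """Return (from_parent, from_time_spine), mirroring the loop added in this hunk."""
    from_parent: Tuple[Instance, ...] = ()
    from_time_spine: Tuple[Instance, ...] = ()
    for instance in instances:
        if (
            instance.spec.element_name == agg_time_element_name
            and instance.spec.entity_links == agg_time_entity_links
        ):
            from_time_spine += (instance,)
        else:
            from_parent += (instance,)
    return from_parent, from_time_spine


# Both metric_time granularities route to the time spine; other dimensions stay with the parent.
instances = (
    Instance(Spec("metric_time", (), "day")),    # filter granularity
    Instance(Spec("metric_time", (), "month")),  # query granularity
    Instance(Spec("ds", ("booking",), "day")),
)
from_parent, from_time_spine = partition(instances, "metric_time", ())
assert len(from_time_spine) == 2 and len(from_parent) == 1
```

This appears to be the heart of the fix: the old loop in the next hunk iterates over `node.requested_agg_time_dimension_specs`, so an agg time dimension spec that appears only in a filter (day grain here) rather than in the group by (month grain) would not be re-selected from the time spine.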
@@ -1335,14 +1346,17 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
         time_spine_select_columns = []
         time_spine_dim_instances = []
         where: Optional[SqlExpressionNode] = None
-        for requested_time_dimension_spec in node.requested_agg_time_dimension_specs:
+        for time_dimension_instance in time_dimensions_to_select_from_time_spine:
+            time_dimension_spec = time_dimension_instance.spec
+
             # Apply granularity to time spine column select expression.
-            if requested_time_dimension_spec.time_granularity == time_spine_dim_instance.spec.time_granularity:
+            if time_dimension_spec.time_granularity == time_spine_dim_instance.spec.time_granularity:
                 select_expr: SqlExpressionNode = time_spine_column_select_expr
             else:
                 select_expr = SqlDateTruncExpression(
-                    time_granularity=requested_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
+                    time_granularity=time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
                 )
+                # Does this work if multiple are selected?
                 if node.offset_to_grain:
                     # Filter down to one row per granularity period
                     new_filter = SqlComparisonExpression(
@@ -1353,13 +1367,13 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
                     else:
                         where = SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where, new_filter))
             # Apply date_part to time spine column select expression.
-            if requested_time_dimension_spec.date_part:
-                select_expr = SqlExtractExpression(date_part=requested_time_dimension_spec.date_part, arg=select_expr)
+            if time_dimension_spec.date_part:
+                select_expr = SqlExtractExpression(date_part=time_dimension_spec.date_part, arg=select_expr)
             time_dim_spec = TimeDimensionSpec(
                 element_name=time_spine_dim_instance.spec.element_name,
                 entity_links=time_spine_dim_instance.spec.entity_links,
-                time_granularity=requested_time_dimension_spec.time_granularity,
-                date_part=requested_time_dimension_spec.date_part,
+                time_granularity=time_dimension_spec.time_granularity,
+                date_part=time_dimension_spec.date_part,
                 aggregation_state=time_spine_dim_instance.spec.aggregation_state,
             )
             time_spine_dim_instance = TimeDimensionInstance(
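When `node.offset_to_grain` is set, each truncated time spine instance contributes a "keep one row per period" comparison, and those comparisons are folded into a single `where` clause: the first filter is assigned directly, and later ones are OR'd onto the running expression (the inline "Does this work if multiple are selected?" question is about exactly this fold). A minimal sketch of the accumulation pattern, using plain strings as a hypothetical stand-in for MetricFlow's SQL expression nodes:

```python
from typing import Optional, Sequence


def fold_filters_with_or(filters: Sequence[str]) -> Optional[str]:
    """Fold filter expressions into one predicate, mirroring the where-accumulation in the diff."""
    where: Optional[str] = None
    for new_filter in filters:
        if where is None:
            where = new_filter  # first filter: assign directly
        else:
            where = f"({where}) OR ({new_filter})"  # later filters: OR onto the running expression
    return where


# E.g., keeping one time spine row per month and one per week (hypothetical filter shapes):
print(fold_filters_with_or(["ds = DATE_TRUNC('month', ds)", "ds = DATE_TRUNC('week', ds)"]))
# -> (ds = DATE_TRUNC('month', ds)) OR (ds = DATE_TRUNC('week', ds))
```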
33 changes: 33 additions & 0 deletions metricflow/test/dataflow/builder/test_dataflow_plan_builder.py
@@ -5,6 +5,7 @@
 
 import pytest
 from _pytest.fixtures import FixtureRequest
+from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
 from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
 from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
 
@@ -1117,4 +1118,36 @@ def test_min_max_only_time_year(
     )
 
 
+@pytest.mark.sql_engine_snapshot
+def test_offset_metric_filter_and_query_have_different_granularities(
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    query_parser: MetricFlowQueryParser,
+    create_source_tables: bool,
+) -> None:
+    """Test a query where an offset metric is queried with one granularity and filtered by a different one."""
+    query_spec = query_parser.parse_and_validate_query(
+        metric_names=("booking_fees_last_week_per_booker_this_week",),
+        group_by_names=("metric_time__month",),
+        where_constraint=PydanticWhereFilter(
+            where_sql_template="{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"
+        ),
+    )
+    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)
+
+    assert_plan_snapshot_text_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        plan=dataflow_plan,
+        plan_snapshot_text=dataflow_plan.text_structure(),
+    )
+
+    display_graph_if_requested(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dag_graph=dataflow_plan,
+    )
+
+
 # TODO: add test for min max metric_time (various granularities) when supported
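The test pins down the granularity mismatch directly: the query groups by `metric_time__month` while the where filter references `metric_time` at day grain. As a rough illustration of how a Jinja where template like the one above resolves to a day-grain column reference (the `time_dimension` callable here is a hypothetical stand-in for MetricFlow's real rendering machinery; assumes `jinja2` is installed):

```python
from jinja2 import Template


def time_dimension(name: str, grain: str) -> str:
    # Hypothetical stand-in: the real renderer resolves this to a qualified column name.
    return f"{name}__{grain}"


template = Template("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'")
rendered = template.render(TimeDimension=time_dimension)
assert rendered == "metric_time__day = '2020-01-01'"
```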
39 changes: 38 additions & 1 deletion metricflow/test/query_rendering/test_derived_metric_rendering.py
@@ -6,7 +6,9 @@
 
 import pytest
 from _pytest.fixtures import FixtureRequest
-from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
+from dbt_semantic_interfaces.implementations.filters.where_filter import (
+    PydanticWhereFilter,
+)
 from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
 
 from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
@@ -701,3 +703,38 @@ def test_nested_fill_nulls_without_time_spine_multi_metric(  # noqa: D
         sql_client=sql_client,
         node=dataflow_plan.sink_output_nodes[0].parent_node,
     )
+
+
+@pytest.mark.sql_engine_snapshot
+def test_offset_metric_filter_and_query_have_different_granularities(
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    sql_client: SqlClient,
+    query_parser: MetricFlowQueryParser,
+    create_source_tables: bool,
+) -> None:
+    """Test a query where an offset metric is queried with one granularity and filtered by a different one."""
+    query_spec = query_parser.parse_and_validate_query(
+        metric_names=("booking_fees_last_week_per_booker_this_week",),
+        group_by_names=("metric_time__month",),
+        where_constraint=PydanticWhereFilter(
+            where_sql_template="{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"
+        ),
+    )
+    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)
+    print(dataflow_plan)
+
+    convert_and_check(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        node=dataflow_plan.sink_output_nodes[0].parent_node,
+    )
+
+
+# WIP notes: the dataflow plan is constructed properly; the issue happens in JoinToTimeSpineNode.
+# Do we remove that column too soon? And does that happen in the dataflow plan build or in the
+# dataflow-to-SQL conversion?
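A toy illustration of the hypothesis in the notes above (column names are hypothetical, and the pre/post behavior is inferred from the diff rather than confirmed): if the time spine join keeps only the requested month-grain column, the day-grain column the where filter needs has already been pruned by the time the filter is applied.

```python
# Pre-fix: the join re-selects only the *requested* agg time dimension specs, so the
# day-grain column referenced by the filter disappears from the join output.
columns_pre_fix = {"metric_time__month", "bookers", "booking_fees"}

# Post-fix: every parent instance of the agg time dimension is re-selected from the time spine.
columns_post_fix = columns_pre_fix | {"metric_time__day"}

filter_column = "metric_time__day"
print(filter_column in columns_pre_fix)   # False -> downstream where clause has nothing to reference
print(filter_column in columns_post_fix)  # True
```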