Skip to content

Commit

Permalink
Bug fix: For measures using join_to_timespine, apply filters after time spine join (#1056)
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Feb 29, 2024
1 parent 2aa5ea5 commit ad9d192
Show file tree
Hide file tree
Showing 19 changed files with 3,031 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240228-084335.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: 'Bug fix: if measure joins to time spine, apply filters again after that join.'
time: 2024-02-28T08:43:35.044076-08:00
custom:
Author: courtneyholcomb
Issue: "1039"
22 changes: 15 additions & 7 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _build_derived_metric_output_node(
)
output_node: BaseOutput = ComputeMetricsNode(parent_node=parent_node, metric_specs=[metric_spec])

# For nested ratio / derived metrics with time offset, apply offset & where constraint after metric computation.
# For ratio / derived metrics with time offset, apply offset & where constraint after metric computation.
if metric_spec.has_time_offset:
queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric(
metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
Expand Down Expand Up @@ -1248,6 +1248,7 @@ def _build_aggregated_measure_from_measure_source_node(
measure_spec_properties=measure_properties,
time_range_constraint=(
(cumulative_metric_adjusted_time_constraint or time_range_constraint)
# If joining to time spine for time offset, constraints will be applied after that join.
if not before_aggregation_time_spine_join_description
else None
),
Expand Down Expand Up @@ -1348,10 +1349,9 @@ def _build_aggregated_measure_from_measure_source_node(
)

pre_aggregate_node: BaseOutput = cumulative_metric_constrained_node or unaggregated_measure_node
merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs)
if len(metric_input_measure_spec.filter_specs) > 0:
# Apply where constraint on the node

merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs)
pre_aggregate_node = WhereConstraintNode(
parent_node=pre_aggregate_node,
where_constraint=merged_where_filter_spec,
Expand Down Expand Up @@ -1395,7 +1395,7 @@ def _build_aggregated_measure_from_measure_source_node(
metric_input_measure_specs=(metric_input_measure_spec,),
)

# Joining to time spine after aggregation is for measures that specify `join_to_timespine`` in the YAML spec.
# Joining to time spine after aggregation is for measures that specify `join_to_timespine` in the YAML spec.
after_aggregation_time_spine_join_description = (
metric_input_measure_spec.after_aggregation_time_spine_join_description
)
Expand All @@ -1404,7 +1404,7 @@ def _build_aggregated_measure_from_measure_source_node(
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
return JoinToTimeSpineNode(
output_node: BaseOutput = JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
Expand All @@ -1413,5 +1413,13 @@ def _build_aggregated_measure_from_measure_source_node(
offset_window=after_aggregation_time_spine_join_description.offset_window,
offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
)
else:
return aggregate_measures_node
# Since new rows might have been added due to time spine join, apply constraints again here.
if len(metric_input_measure_spec.filter_specs) > 0:
output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter_spec)
if time_range_constraint is not None:
output_node = ConstrainTimeRangeNode(
parent_node=output_node, time_range_constraint=time_range_constraint
)
return output_node

return aggregate_measures_node
34 changes: 34 additions & 0 deletions metricflow/test/dataflow/builder/test_dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

Expand Down Expand Up @@ -1169,3 +1170,36 @@ def test_min_max_metric_time_week(
mf_test_session_state=mf_test_session_state,
dag_graph=dataflow_plan,
)


def test_join_to_time_spine_with_filters(
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    dataflow_plan_builder: DataflowPlanBuilder,
    query_parser: MetricFlowQueryParser,
    create_source_tables: bool,
) -> None:
    """Test that filter is not applied until after time spine join."""
    # Filter on a date outside the queried time range; it must be re-applied
    # after the time spine join, not only before aggregation.
    where_filter = PydanticWhereFilter(
        where_sql_template="{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"
    )
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("bookings_fill_nulls_with_0",),
        group_by_names=("metric_time__day",),
        where_constraint=where_filter,
        time_constraint_start=datetime.datetime(2020, 1, 3),
        time_constraint_end=datetime.datetime(2020, 1, 5),
    )
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

    # Compare the built plan's text structure against the stored snapshot.
    assert_plan_snapshot_text_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        plan=dataflow_plan,
        plan_snapshot_text=dataflow_plan.text_structure(),
    )

    # Optionally render the DAG when the test session requests visualization.
    display_graph_if_requested(
        request=request,
        mf_test_session_state=mf_test_session_state,
        dag_graph=dataflow_plan,
    )
36 changes: 36 additions & 0 deletions metricflow/test/query_rendering/test_fill_nulls_with_rendering.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
"""Tests query rendering for coalescing null measures by comparing rendered output against snapshot files."""

from __future__ import annotations

import datetime

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
)
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataset.dataset import DataSet
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.query.query_parser import MetricFlowQueryParser
from metricflow.specs.specs import (
DimensionSpec,
MetricFlowQuerySpec,
Expand Down Expand Up @@ -188,3 +195,32 @@ def test_derived_fill_nulls_for_one_input_metric( # noqa: D
sql_client=sql_client,
node=dataflow_plan.sink_output_nodes[0].parent_node,
)


@pytest.mark.sql_engine_snapshot
def test_join_to_time_spine_with_filters(
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    dataflow_plan_builder: DataflowPlanBuilder,
    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
    sql_client: SqlClient,
    query_parser: MetricFlowQueryParser,
) -> None:
    """Check rendered SQL for a join_to_timespine metric queried with a where filter."""
    # The metric-time filter must survive the time spine join in the rendered SQL.
    where_filter = PydanticWhereFilter(
        where_sql_template="{{ TimeDimension('metric_time') }} > '2020-01-01'",
    )
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("bookings_fill_nulls_with_0",),
        group_by_names=("metric_time__day",),
        where_constraint=where_filter,
        time_constraint_start=datetime.datetime(2020, 1, 3),
        time_constraint_end=datetime.datetime(2020, 1, 5),
    )
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

    # Convert the plan to SQL and compare against the engine-specific snapshot.
    sink_parent = dataflow_plan.sink_output_nodes[0].parent_node
    convert_and_check(
        request=request,
        mf_test_session_state=mf_test_session_state,
        dataflow_to_sql_converter=dataflow_to_sql_converter,
        sql_client=sql_client,
        node=sink_parent,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
<DataflowPlan>
<WriteToResultDataframeNode>
<!-- description = 'Write to Dataframe' -->
<!-- node_id = NodeId(id_str='wrd_0') -->
<ComputeMetricsNode>
<!-- description = 'Compute Metrics via Expressions' -->
<!-- node_id = NodeId(id_str='cm_0') -->
<!-- metric_spec = -->
<!-- MetricSpec( -->
<!-- element_name='bookings_fill_nulls_with_0', -->
<!-- filter_specs=( -->
<!-- WhereFilterSpec( -->
<!-- where_sql="metric_time__day = '2020-01-01'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- time_dimension_specs=( -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<ConstrainTimeRangeNode>
<!-- description = 'Constrain Time Range to [2020-01-03T00:00:00, 2020-01-05T00:00:00]' -->
<!-- node_id = NodeId(id_str='ctr_1') -->
<!-- time_range_start = '2020-01-03T00:00:00' -->
<!-- time_range_end = '2020-01-05T00:00:00' -->
<WhereConstraintNode>
<!-- description = 'Constrain Output with WHERE' -->
<!-- node_id = NodeId(id_str='wcc_1') -->
<!-- where_condition = -->
<!-- WhereFilterSpec( -->
<!-- where_sql="metric_time__day = '2020-01-01'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- time_dimension_specs=( -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<JoinToTimeSpineNode>
<!-- description = 'Join to Time Spine Dataset' -->
<!-- node_id = NodeId(id_str='jts_0') -->
<!-- requested_agg_time_dimension_specs = -->
<!-- [TimeDimensionSpec(element_name='metric_time', time_granularity=DAY),] -->
<!-- use_custom_agg_time_dimension = False -->
<!-- time_range_constraint = -->
<!-- TimeRangeConstraint( -->
<!-- start_time=datetime.datetime(2020, 1, 3, 0, 0), -->
<!-- end_time=datetime.datetime(2020, 1, 5, 0, 0), -->
<!-- ) -->
<!-- offset_window = None -->
<!-- offset_to_grain = None -->
<!-- join_type = LEFT_OUTER -->
<AggregateMeasuresNode>
<!-- description = 'Aggregate Measures' -->
<!-- node_id = NodeId(id_str='am_0') -->
<WhereConstraintNode>
<!-- description = 'Constrain Output with WHERE' -->
<!-- node_id = NodeId(id_str='wcc_0') -->
<!-- where_condition = -->
<!-- WhereFilterSpec( -->
<!-- where_sql="metric_time__day = '2020-01-01'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- time_dimension_specs=( -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['bookings', 'metric_time__day']" -->
<!-- node_id = NodeId(id_str='pfe_0') -->
<!-- include_spec = MeasureSpec(element_name='bookings') -->
<!-- include_spec = -->
<!-- TimeDimensionSpec(element_name='metric_time', time_granularity=DAY) -->
<!-- distinct = False -->
<ConstrainTimeRangeNode>
<!-- description = -->
<!-- 'Constrain Time Range to [2020-01-03T00:00:00, 2020-01-05T00:00:00]' -->
<!-- node_id = NodeId(id_str='ctr_0') -->
<!-- time_range_start = '2020-01-03T00:00:00' -->
<!-- time_range_end = '2020-01-05T00:00:00' -->
<MetricTimeDimensionTransformNode>
<!-- description = "Metric Time Dimension 'ds'" -->
<!-- node_id = NodeId(id_str='sma_28002') -->
<!-- aggregation_time_dimension = 'ds' -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('bookings_source')" -->
<!-- node_id = NodeId(id_str='rss_28014') -->
<!-- data_set = SemanticModelDataSet('bookings_source') -->
</ReadSqlSourceNode>
</MetricTimeDimensionTransformNode>
</ConstrainTimeRangeNode>
</FilterElementsNode>
</WhereConstraintNode>
</AggregateMeasuresNode>
</JoinToTimeSpineNode>
</WhereConstraintNode>
</ConstrainTimeRangeNode>
</ComputeMetricsNode>
</WriteToResultDataframeNode>
</DataflowPlan>
Loading

0 comments on commit ad9d192

Please sign in to comment.