Skip to content

Commit

Permalink
Fill nulls for multi-metric queries (#850)
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Nov 8, 2023
1 parent 9db0fe0 commit dc076a1
Show file tree
Hide file tree
Showing 51 changed files with 1,359 additions and 32 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231106-150014.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Fill nulls for multi-metric queries
time: 2023-11-06T15:00:14.37926-08:00
custom:
Author: courtneyholcomb
Issue: "850"
2 changes: 1 addition & 1 deletion metricflow/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class EntityInstance(MdoInstance[EntitySpec], SemanticModelElementInstance): #
class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D
associated_columns: Tuple[ColumnAssociation, ...]
spec: MetricSpec
defined_from: Tuple[MetricModelReference, ...]
defined_from: MetricModelReference


@dataclass(frozen=True)
Expand Down
22 changes: 20 additions & 2 deletions metricflow/model/semantics/metric_lookup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import logging
from typing import Dict, FrozenSet, List, Sequence
from typing import Dict, FrozenSet, List, Optional, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection
from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow
from dbt_semantic_interfaces.protocols import WhereFilter
from dbt_semantic_interfaces.protocols.metric import Metric, MetricType
from dbt_semantic_interfaces.protocols.metric import Metric, MetricInputMeasure, MetricType
from dbt_semantic_interfaces.protocols.semantic_manifest import SemanticManifest
from dbt_semantic_interfaces.references import MetricReference

Expand Down Expand Up @@ -105,6 +106,23 @@ def add_metric(self, metric: Metric) -> None:
)
self._metrics[metric_reference] = metric

def configured_input_measure_for_metric(self, metric_reference: MetricReference) -> Optional[MetricInputMeasure]:
    """Return the input measure declared in the metric's own config, if there is one.

    When the SemanticModel is built, input measures inherited from input metrics
    are appended to a metric's measure list, so the metric type is used here to
    recover which measure (if any) came from the original config:
    - Simple & cumulative metrics declare exactly one input measure and accept
      no input metrics.
    - Derived & ratio metrics declare only input metrics, never input measures.
    """
    metric = self.get_metric(metric_reference=metric_reference)
    metric_type = metric.type
    if metric_type is MetricType.SIMPLE or metric_type is MetricType.CUMULATIVE:
        # Anything beyond a single measure here would mean the config was malformed.
        assert len(metric.input_measures) == 1, "Simple and cumulative metrics should have one input measure."
        return metric.input_measures[0]
    if metric_type is MetricType.RATIO or metric_type is MetricType.DERIVED:
        return None
    assert_values_exhausted(metric_type)

def measures_for_metric(
self,
metric_reference: MetricReference,
Expand Down
41 changes: 29 additions & 12 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import logging
from collections import OrderedDict
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import MetricModelReference
from dbt_semantic_interfaces.references import MetricModelReference, MetricReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType

from metricflow.aggregation_properties import AggregationState
Expand Down Expand Up @@ -628,7 +628,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
metric_instances.append(
MetricInstance(
associated_columns=(output_column_association,),
defined_from=(MetricModelReference(metric_name=metric_spec.element_name),),
defined_from=MetricModelReference(metric_name=metric_spec.element_name),
spec=metric_spec.alias_spec,
)
)
Expand Down Expand Up @@ -784,24 +784,25 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
),
)

def _make_select_columns_for_metrics(
def _make_select_columns_for_multiple_metrics(
self,
table_alias_to_metric_specs: OrderedDict[str, Sequence[MetricSpec]],
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]],
aggregation_type: Optional[AggregationType],
) -> List[SqlSelectColumn]:
"""Creates select columns that get the given metric using the given table alias.
e.g.
with table_alias_to_metric_specs = {"a": MetricSpec(element_name="bookings")}
with table_alias_to_metric_instances = {"a": MetricSpec(element_name="bookings")}
->
a.bookings AS bookings
"""
select_columns = []
for table_alias, metric_specs in table_alias_to_metric_specs.items():
for metric_spec in metric_specs:
for table_alias, metric_instances in table_alias_to_metric_instances.items():
for metric_instance in metric_instances:
metric_spec = metric_instance.spec
metric_column_name = self._column_association_resolver.resolve_spec(metric_spec).column_name
column_reference_expression = SqlColumnReferenceExpression(
col_ref=SqlColumnReference(
Expand All @@ -816,6 +817,21 @@ def _make_select_columns_for_metrics(
else:
select_expression = column_reference_expression

# At this point, the MetricSpec might have the alias in place of the element name, so we need to look
# back at where it was defined from to get the metric element name.
metric_reference = MetricReference(element_name=metric_instance.defined_from.metric_name)
input_measure = self._metric_lookup.configured_input_measure_for_metric(
metric_reference=metric_reference
)
if input_measure and input_measure.fill_nulls_with is not None:
select_expression = SqlAggregateFunctionExpression(
sql_function=SqlFunction.COALESCE,
sql_function_args=[
select_expression,
SqlStringExpression(str(input_measure.fill_nulls_with)),
],
)

select_columns.append(
SqlSelectColumn(
expr=select_expression,
Expand Down Expand Up @@ -860,13 +876,13 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
), "Shouldn't have a CombineMetricsNode in the dataflow plan if there's only 1 parent."

parent_data_sets: List[AnnotatedSqlDataSet] = []
table_alias_to_metric_specs: OrderedDict[str, Sequence[MetricSpec]] = OrderedDict()
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]] = OrderedDict()

for parent_node in node.parent_nodes:
parent_sql_data_set = parent_node.accept(self)
table_alias = self._next_unique_table_alias()
parent_data_sets.append(AnnotatedSqlDataSet(data_set=parent_sql_data_set, alias=table_alias))
table_alias_to_metric_specs[table_alias] = parent_sql_data_set.instance_set.spec_set.metric_specs
table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances

# When we create the components of the join that combines metrics it will be one of INNER, FULL OUTER,
# or CROSS JOIN. Order doesn't matter for these join types, so we will use the first element in the FROM
Expand Down Expand Up @@ -908,8 +924,9 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:

metric_aggregation_type = AggregationType.MAX
metric_select_column_set = SelectColumnSet(
metric_columns=self._make_select_columns_for_metrics(
table_alias_to_metric_specs, aggregation_type=metric_aggregation_type
metric_columns=self._make_select_columns_for_multiple_metrics(
table_alias_to_metric_instances=table_alias_to_metric_instances,
aggregation_type=metric_aggregation_type,
)
)
linkable_select_column_set = linkable_spec_set.transform(
Expand Down
2 changes: 1 addition & 1 deletion metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def from_reference(reference: MetricReference) -> MetricSpec:

@property
def alias_spec(self) -> MetricSpec:
"""Returns a MetricSpec represneting the alias state."""
"""Returns a MetricSpec representing the alias state."""
return MetricSpec(
element_name=self.alias or self.element_name,
constraint=self.constraint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ table_snapshot:
- ["2020-01-02", "2020-01-02", "u1612112", "l2718281"]
- ["2020-01-02", "2020-01-02", "u0004114", ""]
- ["2020-01-02", "2020-01-02", "u0004114", "l7891283-incomplete"]
- ["2020-01-04", "2020-01-02", "u1612112", "l2718281"]
- ["2020-01-05", "2020-01-02", "u0004114", ""]
Empty file.
145 changes: 145 additions & 0 deletions metricflow/test/integration/query_output/test_fill_nulls_with_0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

import datetime

import pytest
from _pytest.fixtures import FixtureRequest

from metricflow.engine.metricflow_engine import MetricFlowQueryRequest
from metricflow.protocols.sql_client import SqlClient
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.integration.conftest import IntegrationTestHelpers
from metricflow.test.snapshot_utils import assert_object_snapshot_equal


@pytest.mark.sql_engine_snapshot
def test_simple_fill_nulls_with_0_metric_time(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    # Query a single fill-nulls-with-0 metric grouped/ordered by metric_time.
    mf_request = MetricFlowQueryRequest.create_with_random_request_id(
        metric_names=["bookings_fill_nulls_with_0"],
        group_by_names=["metric_time"],
        order_by_names=["metric_time"],
        time_constraint_start=datetime.datetime(2019, 11, 27),
        time_constraint_end=datetime.datetime(2020, 1, 5),
    )
    query_result = it_helpers.mf_engine.query(mf_request)
    assert query_result.result_df is not None, "Unexpected empty result."

    # Compare the rendered result frame against the stored snapshot.
    assert_object_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        obj_id="query_output",
        obj=query_result.result_df.to_string(),
        sql_client=sql_client,
    )


@pytest.mark.sql_engine_snapshot
def test_simple_fill_nulls_with_0_month(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    # Same fill-nulls metric as the metric_time test, but at month granularity
    # over a wider time window.
    mf_request = MetricFlowQueryRequest.create_with_random_request_id(
        metric_names=["bookings_fill_nulls_with_0"],
        group_by_names=["metric_time__month"],
        order_by_names=["metric_time__month"],
        time_constraint_start=datetime.datetime(2019, 1, 1),
        time_constraint_end=datetime.datetime(2020, 12, 1),
    )
    query_result = it_helpers.mf_engine.query(mf_request)
    assert query_result.result_df is not None, "Unexpected empty result."

    # Compare the rendered result frame against the stored snapshot.
    assert_object_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        obj_id="query_output",
        obj=query_result.result_df.to_string(),
        sql_client=sql_client,
    )


@pytest.mark.sql_engine_snapshot
def test_simple_join_to_time_spine(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    # Query a metric configured to join to the time spine (no fill value).
    mf_request = MetricFlowQueryRequest.create_with_random_request_id(
        metric_names=["bookings_join_to_time_spine"],
        group_by_names=["metric_time"],
        order_by_names=["metric_time"],
        time_constraint_start=datetime.datetime(2019, 11, 27),
        time_constraint_end=datetime.datetime(2020, 1, 5),
    )
    query_result = it_helpers.mf_engine.query(mf_request)
    assert query_result.result_df is not None, "Unexpected empty result."

    # Compare the rendered result frame against the stored snapshot.
    assert_object_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        obj_id="query_output",
        obj=query_result.result_df.to_string(),
        sql_client=sql_client,
    )


@pytest.mark.sql_engine_snapshot
def test_fill_nulls_with_0_multi_metric_query(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    # Combine a fill-nulls metric with a plain metric so the multi-metric
    # (CombineMetricsNode) path must apply the COALESCE.
    mf_request = MetricFlowQueryRequest.create_with_random_request_id(
        metric_names=["bookings_fill_nulls_with_0", "views"],
        group_by_names=["metric_time"],
        order_by_names=["metric_time"],
        time_constraint_start=datetime.datetime(2019, 11, 27),
        time_constraint_end=datetime.datetime(2020, 1, 5),
    )
    query_result = it_helpers.mf_engine.query(mf_request)
    assert query_result.result_df is not None, "Unexpected empty result."

    # Compare the rendered result frame against the stored snapshot.
    assert_object_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        obj_id="query_output",
        obj=query_result.result_df.to_string(),
        sql_client=sql_client,
    )


@pytest.mark.sql_engine_snapshot
def test_fill_nulls_with_0_multi_metric_query_with_categorical_dimension(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    it_helpers: IntegrationTestHelpers,
) -> None:
    # Multi-metric fill-nulls query grouped by a categorical dimension as well
    # as metric_time; no time constraint and no time-spine join for the metric.
    mf_request = MetricFlowQueryRequest.create_with_random_request_id(
        metric_names=["bookings_fill_nulls_with_0_without_time_spine", "views"],
        group_by_names=["metric_time", "listing__is_lux_latest"],
        order_by_names=["metric_time", "listing__is_lux_latest"],
    )
    query_result = it_helpers.mf_engine.query(mf_request)
    assert query_result.result_df is not None, "Unexpected empty result."

    # Compare the rendered result frame against the stored snapshot.
    assert_object_snapshot_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        obj_id="query_output",
        obj=query_result.result_df.to_string(),
        sql_client=sql_client,
    )
Loading

0 comments on commit dc076a1

Please sign in to comment.