Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fill nulls for multi-metric queries #850

Merged
merged 9 commits into from
Nov 8, 2023
Merged
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231106-150014.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Fill nulls for multi-metric queries
time: 2023-11-06T15:00:14.37926-08:00
custom:
Author: courtneyholcomb
Issue: "850"
2 changes: 1 addition & 1 deletion metricflow/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class EntityInstance(MdoInstance[EntitySpec], SemanticModelElementInstance): #
class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D
associated_columns: Tuple[ColumnAssociation, ...]
spec: MetricSpec
defined_from: Tuple[MetricModelReference, ...]
defined_from: MetricModelReference


@dataclass(frozen=True)
Expand Down
22 changes: 20 additions & 2 deletions metricflow/model/semantics/metric_lookup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import logging
from typing import Dict, FrozenSet, List, Sequence
from typing import Dict, FrozenSet, List, Optional, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection
from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow
from dbt_semantic_interfaces.protocols import WhereFilter
from dbt_semantic_interfaces.protocols.metric import Metric, MetricType
from dbt_semantic_interfaces.protocols.metric import Metric, MetricInputMeasure, MetricType
from dbt_semantic_interfaces.protocols.semantic_manifest import SemanticManifest
from dbt_semantic_interfaces.references import MetricReference

Expand Down Expand Up @@ -105,6 +106,23 @@ def add_metric(self, metric: Metric) -> None:
)
self._metrics[metric_reference] = metric

def configured_input_measure_for_metric(self, metric_reference: MetricReference) -> Optional[MetricInputMeasure]:
"""Get input measure defined in the original metric config, if exists.

When SemanticModel is constructed, input measures from input metrics are added to the list of input measures
for a metric. Here, use rules about metric types to determine which input measures were defined in the config:
- Simple & cumulative metrics require one input measure, and can't take any input metrics.
- Derived & ratio metrics take no input measures, only input metrics.
"""
metric = self.get_metric(metric_reference=metric_reference)
if metric.type is MetricType.CUMULATIVE or metric.type is MetricType.SIMPLE:
assert len(metric.input_measures) == 1, "Simple and cumulative metrics should have one input measure."
return metric.input_measures[0]
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
return None
else:
assert_values_exhausted(metric.type)

def measures_for_metric(
self,
metric_reference: MetricReference,
Expand Down
41 changes: 29 additions & 12 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import logging
from collections import OrderedDict
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import MetricModelReference
from dbt_semantic_interfaces.references import MetricModelReference, MetricReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType

from metricflow.aggregation_properties import AggregationState
Expand Down Expand Up @@ -706,7 +706,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
metric_instances.append(
MetricInstance(
associated_columns=(output_column_association,),
defined_from=(MetricModelReference(metric_name=metric_spec.element_name),),
defined_from=MetricModelReference(metric_name=metric_spec.element_name),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOL, thanks. We have so many of these kind of "let's allow a sequence of things and then have exactly one callsite that manually enforces there's only one" interfaces lying around...

spec=metric_spec.alias_spec,
)
)
Expand Down Expand Up @@ -862,24 +862,25 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
),
)

def _make_select_columns_for_metrics(
def _make_select_columns_for_multiple_metrics(
self,
table_alias_to_metric_specs: OrderedDict[str, Sequence[MetricSpec]],
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]],
aggregation_type: Optional[AggregationType],
) -> List[SqlSelectColumn]:
"""Creates select columns that get the given metric using the given table alias.

e.g.

with table_alias_to_metric_specs = {"a": MetricSpec(element_name="bookings")}
with table_alias_to_metric_instances = {"a": MetricSpec(element_name="bookings")}

->

a.bookings AS bookings
"""
select_columns = []
for table_alias, metric_specs in table_alias_to_metric_specs.items():
for metric_spec in metric_specs:
for table_alias, metric_instances in table_alias_to_metric_instances.items():
for metric_instance in metric_instances:
metric_spec = metric_instance.spec
metric_column_name = self._column_association_resolver.resolve_spec(metric_spec).column_name
column_reference_expression = SqlColumnReferenceExpression(
col_ref=SqlColumnReference(
Expand All @@ -894,6 +895,21 @@ def _make_select_columns_for_metrics(
else:
select_expression = column_reference_expression

# At this point, the MetricSpec might have the alias in place of the element name, so we need to look
# back at where it was defined from to get the metric element name.
metric_reference = MetricReference(element_name=metric_instance.defined_from.metric_name)
input_measure = self._metric_lookup.configured_input_measure_for_metric(
metric_reference=metric_reference
)
if input_measure and input_measure.fill_nulls_with is not None:
select_expression = SqlAggregateFunctionExpression(
sql_function=SqlFunction.COALESCE,
sql_function_args=[
select_expression,
SqlStringExpression(str(input_measure.fill_nulls_with)),
],
)

select_columns.append(
SqlSelectColumn(
expr=select_expression,
Expand Down Expand Up @@ -938,13 +954,13 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
), "Shouldn't have a CombineMetricsNode in the dataflow plan if there's only 1 parent."

parent_data_sets: List[AnnotatedSqlDataSet] = []
table_alias_to_metric_specs: OrderedDict[str, Sequence[MetricSpec]] = OrderedDict()
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]] = OrderedDict()

for parent_node in node.parent_nodes:
parent_sql_data_set = parent_node.accept(self)
table_alias = self._next_unique_table_alias()
parent_data_sets.append(AnnotatedSqlDataSet(data_set=parent_sql_data_set, alias=table_alias))
table_alias_to_metric_specs[table_alias] = parent_sql_data_set.instance_set.spec_set.metric_specs
table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances

# When we create the components of the join that combines metrics it will be one of INNER, FULL OUTER,
# or CROSS JOIN. Order doesn't matter for these join types, so we will use the first element in the FROM
Expand Down Expand Up @@ -986,8 +1002,9 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:

metric_aggregation_type = AggregationType.MAX
metric_select_column_set = SelectColumnSet(
metric_columns=self._make_select_columns_for_metrics(
table_alias_to_metric_specs, aggregation_type=metric_aggregation_type
metric_columns=self._make_select_columns_for_multiple_metrics(
table_alias_to_metric_instances=table_alias_to_metric_instances,
aggregation_type=metric_aggregation_type,
)
)
linkable_select_column_set = linkable_spec_set.transform(
Expand Down
2 changes: 1 addition & 1 deletion metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def from_reference(reference: MetricReference) -> MetricSpec:

@property
def alias_spec(self) -> MetricSpec:
"""Returns a MetricSpec represneting the alias state."""
"""Returns a MetricSpec representing the alias state."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😸

return MetricSpec(
element_name=self.alias or self.element_name,
constraint=self.constraint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ table_snapshot:
- ["2020-01-02", "2020-01-02", "u1612112", "l2718281"]
- ["2020-01-02", "2020-01-02", "u0004114", ""]
- ["2020-01-02", "2020-01-02", "u0004114", "l7891283-incomplete"]
- ["2020-01-04", "2020-01-02", "u1612112", "l2718281"]
- ["2020-01-05", "2020-01-02", "u0004114", ""]
Empty file.
145 changes: 145 additions & 0 deletions metricflow/test/integration/query_output/test_fill_nulls_with_0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

import datetime

import pytest
from _pytest.fixtures import FixtureRequest

from metricflow.engine.metricflow_engine import MetricFlowQueryRequest
from metricflow.protocols.sql_client import SqlClient
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.integration.conftest import IntegrationTestHelpers
from metricflow.test.snapshot_utils import assert_object_snapshot_equal


@pytest.mark.sql_engine_snapshot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These snapshot tests are great, thanks for doing this!

def test_simple_fill_nulls_with_0_metric_time( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
sql_client: SqlClient,
it_helpers: IntegrationTestHelpers,
) -> None:
query_result = it_helpers.mf_engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=["bookings_fill_nulls_with_0"],
group_by_names=["metric_time"],
order_by_names=["metric_time"],
time_constraint_start=datetime.datetime(2019, 11, 27),
time_constraint_end=datetime.datetime(2020, 1, 5),
)
)
assert query_result.result_df is not None, "Unexpected empty result."

assert_object_snapshot_equal(
request=request,
mf_test_session_state=mf_test_session_state,
obj_id="query_output",
obj=query_result.result_df.to_string(),
sql_client=sql_client,
)


@pytest.mark.sql_engine_snapshot
def test_simple_fill_nulls_with_0_month( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
sql_client: SqlClient,
it_helpers: IntegrationTestHelpers,
) -> None:
query_result = it_helpers.mf_engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=["bookings_fill_nulls_with_0"],
group_by_names=["metric_time__month"],
order_by_names=["metric_time__month"],
time_constraint_start=datetime.datetime(2019, 1, 1),
time_constraint_end=datetime.datetime(2020, 12, 1),
)
)
assert query_result.result_df is not None, "Unexpected empty result."

assert_object_snapshot_equal(
request=request,
mf_test_session_state=mf_test_session_state,
obj_id="query_output",
obj=query_result.result_df.to_string(),
sql_client=sql_client,
)


@pytest.mark.sql_engine_snapshot
def test_simple_join_to_time_spine( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
sql_client: SqlClient,
it_helpers: IntegrationTestHelpers,
) -> None:
query_result = it_helpers.mf_engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=["bookings_join_to_time_spine"],
group_by_names=["metric_time"],
time_constraint_start=datetime.datetime(2019, 11, 27),
time_constraint_end=datetime.datetime(2020, 1, 5),
order_by_names=["metric_time"],
)
)
assert query_result.result_df is not None, "Unexpected empty result."

assert_object_snapshot_equal(
request=request,
mf_test_session_state=mf_test_session_state,
obj_id="query_output",
obj=query_result.result_df.to_string(),
sql_client=sql_client,
)


@pytest.mark.sql_engine_snapshot
def test_fill_nulls_with_0_multi_metric_query( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
sql_client: SqlClient,
it_helpers: IntegrationTestHelpers,
) -> None:
query_result = it_helpers.mf_engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=["bookings_fill_nulls_with_0", "views"],
group_by_names=["metric_time"],
order_by_names=["metric_time"],
time_constraint_start=datetime.datetime(2019, 11, 27),
time_constraint_end=datetime.datetime(2020, 1, 5),
)
)
assert query_result.result_df is not None, "Unexpected empty result."

assert_object_snapshot_equal(
request=request,
mf_test_session_state=mf_test_session_state,
obj_id="query_output",
obj=query_result.result_df.to_string(),
sql_client=sql_client,
)


@pytest.mark.sql_engine_snapshot
def test_fill_nulls_with_0_multi_metric_query_with_categorical_dimension( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
sql_client: SqlClient,
it_helpers: IntegrationTestHelpers,
) -> None:
query_result = it_helpers.mf_engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=["bookings_fill_nulls_with_0_without_time_spine", "views"],
group_by_names=["metric_time", "listing__is_lux_latest"],
order_by_names=["metric_time", "listing__is_lux_latest"],
)
)
assert query_result.result_df is not None, "Unexpected empty result."

assert_object_snapshot_equal(
request=request,
mf_test_session_state=mf_test_session_state,
obj_id="query_output",
obj=query_result.result_df.to_string(),
sql_client=sql_client,
)
Loading